
perf(autotuner): replace power-of-2 token buckets with hybrid spacing & fix missing routing_replay_out arg #3115

Open

StudyingShao wants to merge 3 commits into flashinfer-ai:main from StudyingShao:jiangs/autotuner-hybrid-bucket-spacing-main

Conversation

@StudyingShao commented Apr 18, 2026

📌 Description

This PR includes two improvements:

  1. perf(autotuner): Replace power-of-2 token buckets with hybrid spacing — Pure power-of-2 spacing creates huge gaps at large values (e.g. a jump from 1024 to 2048), forcing the autotuner to pick a kernel optimised for a very different workload size. The new hybrid scheme uses four phases with progressively coarser spacing:

    • Phase 1: [min .. 256] — power-of-2 (step ×2)
    • Phase 2: (256 .. 2048] — linear step 256
    • Phase 3: (2048 .. 4096] — linear step 512
    • Phase 4: (4096 .. max] — power-of-2 (step ×2)

    All callsites in MoE, GEMM, and low-latency GEMM autotuners are updated to use the new get_hybrid_num_tokens_buckets / map_to_hybrid_bucket API.

  2. fix: Pass missing routing_replay_out arg to trtllm_fp8_per_tensor_scale_moe — Two call sites in fused_moe/core.py were missing the routing_replay_out argument, causing it to be silently dropped.
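The four-phase scheme above can be sketched end to end. This is a hypothetical reconstruction from the description in this PR: the constants and the names get_hybrid_num_tokens_buckets / map_to_hybrid_bucket mirror the discussion, but the shipped code in flashinfer/fused_moe/utils.py may differ in details.

```python
# Sketch of the hybrid bucketing scheme described above (assumed constants).
_PHASE1_END = 256    # power-of-2 spacing up to here
_PHASE2_END = 2048   # linear step 256 up to here
_PHASE3_END = 4096   # linear step 512 up to here
_PHASE2_STEP = 256
_PHASE3_STEP = 512


def next_positive_power_of_2(x: int) -> int:
    return 1 if x <= 1 else 1 << (x - 1).bit_length()


def get_hybrid_num_tokens_buckets(max_num_tokens: int) -> tuple:
    buckets = []
    m = 1
    while m <= min(max_num_tokens, _PHASE1_END):  # Phase 1: step x2
        buckets.append(m)
        m *= 2
    m = _PHASE1_END + _PHASE2_STEP
    while m <= min(max_num_tokens, _PHASE2_END):  # Phase 2: step +256
        buckets.append(m)
        m += _PHASE2_STEP
    m = _PHASE2_END + _PHASE3_STEP
    while m <= min(max_num_tokens, _PHASE3_END):  # Phase 3: step +512
        buckets.append(m)
        m += _PHASE3_STEP
    m = _PHASE3_END * 2
    while m <= max_num_tokens:                    # Phase 4: step x2
        buckets.append(m)
        m *= 2
    if not buckets or buckets[-1] != max_num_tokens:
        buckets.append(max_num_tokens)            # always cover the max
    return tuple(sorted(set(buckets)))


def map_to_hybrid_bucket(num_tokens: int) -> int:
    """Round a runtime token count up to its bucket."""
    if num_tokens <= _PHASE1_END:
        return next_positive_power_of_2(num_tokens)
    if num_tokens <= _PHASE2_END:
        return -(-num_tokens // _PHASE2_STEP) * _PHASE2_STEP  # ceil to 256
    if num_tokens <= _PHASE3_END:
        return -(-num_tokens // _PHASE3_STEP) * _PHASE3_STEP  # ceil to 512
    return next_positive_power_of_2(num_tokens)


buckets = get_hybrid_num_tokens_buckets(8192)
assert map_to_hybrid_bucket(1100) == 1280 and 1280 in buckets
```

Under pure power-of-2 spacing, 1100 tokens would round up to 2048; the hybrid scheme keeps the rounding error at most 256 in that range.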

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Changed files:

  • flashinfer/fused_moe/utils.py — Core implementation: new get_hybrid_num_tokens_buckets, map_to_hybrid_bucket, map_to_hybrid_bucket_uncapped; removed old get_last_power_of_2_num_tokens_buckets
  • flashinfer/fused_moe/core.py — Updated all MoE autotuner callsites + added missing routing_replay_out arg
  • flashinfer/fused_moe/cute_dsl/tuner.py — Updated CuTe DSL FP4 MoE tuner callsite
  • flashinfer/gemm/gemm_base.py — Updated GEMM (FP8, BF16, FP4, MXFP8, TGV) autotuner configs
  • flashinfer/trtllm_low_latency_gemm.py — Updated low-latency GEMM autotuner config

Summary by CodeRabbit

  • Improvements

    • Updated autotuning bucketing to a hybrid token-bucketing scheme, producing more representative dynamic profiles for MoE and GEMM tuning.
    • More accurate mapping of dynamic token counts into tuning buckets, improving kernel selection and performance profiling.
  • New Features

    • FP8 MoE execution now forwards an optional routing replay output through the execution path, enabling downstream routing diagnostics.

…le_moe

Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
@coderabbitai Bot (Contributor) commented Apr 18, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a8ee289c-8ad1-404d-b6e0-581a5ade33b2

📥 Commits

Reviewing files that changed from the base of the PR and between 5027cbf and ee52bd9.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/gemm/gemm_base.py

📝 Walkthrough

Walkthrough

Replaces power-of-2 token-bucketing with a new four-phase hybrid bucketing across MoE and GEMM autotuning; updates callsites to use the new mapping utilities. Also threads an optional routing_replay_out tensor from the runner into the FP8 per-tensor MoE C++ kernel invocation.

Changes

Cohort / File(s) Summary
Token-Bucketing Core Utilities
flashinfer/fused_moe/utils.py
Removed power-of-2 bucket helpers; added four-phase get_hybrid_num_tokens_buckets, map_to_hybrid_bucket, map_to_hybrid_bucket_uncapped, phase constants, and _ceil_to_step.
MoE Autotuning & Runtime
flashinfer/fused_moe/core.py, flashinfer/fused_moe/cute_dsl/tuner.py
Switched DynamicTensorSpec token-dimension bucketing to hybrid functions. Extended FP8 per-tensor MoE execution path to forward optional routing_replay_out into the trtllm_fp8_per_tensor_scale_moe_op call.
GEMM Autotuning Updates
flashinfer/gemm/gemm_base.py, flashinfer/trtllm_low_latency_gemm.py
Replaced power-of-2 token bucketing with hybrid bucketing and uncapped mapping for various FP8/TGV/MXFP8 GEMM dynamic specs.

Sequence Diagram(s)

sequenceDiagram
    participant Client as Client/Caller
    participant Tuner as Autotuner/Tuner
    participant Utils as Bucketing Utils
    participant Runner as MoE Runner
    participant Kernel as trtllm_fp8_per_tensor_scale_moe_op

    Client->>Tuner: request tuning / forward(input with num_tokens)
    Tuner->>Utils: get_hybrid_num_tokens_buckets(max_tokens)
    Tuner->>Utils: map_to_hybrid_bucket(num_tokens, max_tokens)
    Tuner->>Runner: select tactic / provide mapped bucket
    Client->>Runner: forward(..., routing_replay_out=?)
    Runner->>Kernel: call trtllm_fp8_per_tensor_scale_moe_op(..., routing_replay_out)
    Kernel-->>Runner: result
    Runner-->>Client: output

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~28 minutes

Possibly related PRs

Suggested reviewers

  • yzh119
  • nv-yunzheq
  • aleozlx
  • sricketts
  • bkryu
  • jiahanc

Poem

🐰
Buckets braided, four-phase bright,
Tokens hop in measured flight.
Rerouted threads and tuned delight,
Runner hums through day and night,
A rabbit cheers: "Profiles take flight!" 🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 40.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check — ✅ Passed: the title accurately summarizes both main changes: replacing power-of-2 token buckets with hybrid spacing and fixing the missing routing_replay_out argument.
  • Description check — ✅ Passed: the description is comprehensive, covering both improvements with detailed explanations, the hybrid spacing phases, affected files, and completed checklist items.
  • Linked Issues check — ✅ Passed: check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check — ✅ Passed: check skipped because no linked issues were found for this pull request.



@gemini-code-assist Bot (Contributor) left a comment


Code Review

This pull request replaces the power-of-2 token bucket generation logic with a hybrid approach across several modules to improve autotuning for MoE workloads. The new logic uses four phases with varying spacing, including power-of-2 and linear steps. Additionally, a routing_replay_out parameter is added to the MoE forward functions. A logic inconsistency was identified in get_hybrid_num_tokens_buckets where the generated buckets may not align with the mapping function when min_num_tokens is greater than one, which could lead to autotuner failures.

Comment on lines +227 to +253
    m = max(min_num_tokens, 1)
    while m <= min(max_num_tokens, _PHASE1_END):
        buckets.append(m)
        m *= 2

    # Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
    m = _PHASE1_END + _PHASE2_STEP
    while m <= min(max_num_tokens, _PHASE2_END):
        buckets.append(m)
        m += _PHASE2_STEP

    # Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
    m = _PHASE2_END + _PHASE3_STEP
    while m <= min(max_num_tokens, _PHASE3_END):
        buckets.append(m)
        m += _PHASE3_STEP

    # Phase 4: power-of-2 beyond _PHASE3_END
    m = _PHASE3_END * 2
    while m <= max_num_tokens:
        buckets.append(m)
        m *= 2

    if not buckets or buckets[-1] != max_num_tokens:
        buckets.append(max_num_tokens)

    return tuple(sorted(set(buckets)))

Severity: high

The implementation of get_hybrid_num_tokens_buckets has a critical inconsistency with map_to_hybrid_bucket when min_num_tokens > 1.

  1. Phase 1 Mismatch: If min_num_tokens is not a power of 2 (e.g., 10), Phase 1 currently generates buckets starting from that value (e.g., [10, 20, 40, ...]). However, map_to_hybrid_bucket uses next_positive_power_of_2(x) for Phase 1, which means an input of size 10 will map to bucket 16. Since 16 is not in the generated list, the autotuner will fail to find a tuned tactic for this size.
  2. Phase 2-4 Filtering: The loops for subsequent phases use fixed starting points (e.g., _PHASE1_END + _PHASE2_STEP), which results in buckets smaller than min_num_tokens being added to the list if min_num_tokens is large.

The robust fix is to always generate the full set of potential buckets starting from 1 (to ensure consistency with the mapping logic) and then filter the final result to keep only those within the [min_num_tokens, max_num_tokens] range.

    buckets: List[int] = []

    # Phase 1: power-of-2 up to _PHASE1_END
    m = 1
    while m <= min(max_num_tokens, _PHASE1_END):
        buckets.append(m)
        m *= 2

    # Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
    m = _PHASE1_END + _PHASE2_STEP
    while m <= min(max_num_tokens, _PHASE2_END):
        buckets.append(m)
        m += _PHASE2_STEP

    # Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
    m = _PHASE2_END + _PHASE3_STEP
    while m <= min(max_num_tokens, _PHASE3_END):
        buckets.append(m)
        m += _PHASE3_STEP

    # Phase 4: power-of-2 beyond _PHASE3_END
    m = _PHASE3_END * 2
    while m <= max_num_tokens:
        buckets.append(m)
        m *= 2

    if not buckets or buckets[-1] != max_num_tokens:
        buckets.append(max_num_tokens)

    return tuple(sorted(set(b for b in buckets if b >= min_num_tokens and b <= max_num_tokens)))
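The phase-1 mismatch described in point 1 can be reproduced in isolation. This hypothetical snippet re-implements only phase 1 of the generator as written in the PR, plus the mapper's power-of-2 rounding; names mirror the review comment, not necessarily the real module.

```python
# Hypothetical reproduction of the phase-1 mismatch described above.

def next_positive_power_of_2(x: int) -> int:
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

def phase1_buckets_as_written(min_num_tokens: int, max_num_tokens: int) -> list:
    # Phase 1 as written in the PR: doubles starting from min_num_tokens itself.
    buckets, m = [], max(min_num_tokens, 1)
    while m <= min(max_num_tokens, 256):
        buckets.append(m)
        m *= 2
    return buckets

buckets = phase1_buckets_as_written(10, 256)
assert buckets == [10, 20, 40, 80, 160]    # non-power-of-2 buckets generated
assert next_positive_power_of_2(10) == 16  # ...but the mapper rounds 10 to 16
assert 16 not in buckets                   # lookup misses: no tuned tactic
```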

@coderabbitai Bot (Contributor) left a comment

Actionable comments posted: 2


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f83d9649-3a2d-43ac-98e8-a7f2b490a9f6

📥 Commits

Reviewing files that changed from the base of the PR and between 7d0f68e and 5027cbf.

📒 Files selected for processing (5)
  • flashinfer/fused_moe/core.py
  • flashinfer/fused_moe/cute_dsl/tuner.py
  • flashinfer/fused_moe/utils.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/trtllm_low_latency_gemm.py

Comment on lines +217 to +223
    This function uses four phases with progressively coarser spacing::

        Phase 1:  [min .. 256]   — power-of-2    (step ×2)
        Phase 2:  (256 .. 2048]  — linear step 256
        Phase 3:  (2048 .. 4096] — linear step 512
        Phase 4:  (4096 .. max]  — power-of-2    (step ×2)
    """

⚠️ Potential issue | 🟡 Minor

Replace ambiguous multiplication signs in the docstring.

Ruff flags the Unicode × characters here; use plain x to keep pre-commit clean.

Proposed fix
-        Phase 1:  [min .. 256]   — power-of-2    (step ×2)
+        Phase 1:  [min .. 256]   — power-of-2    (step x2)
         Phase 2:  (256 .. 2048]  — linear step 256
         Phase 3:  (2048 .. 4096] — linear step 512
-        Phase 4:  (4096 .. max]  — power-of-2    (step ×2)
+        Phase 4:  (4096 .. max]  — power-of-2    (step x2)
🧰 Tools
🪛 Ruff (0.15.10)

[warning] 219-219: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF002)


[warning] 222-222: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF002)


Comment on lines +224 to +253
    buckets: List[int] = []

    # Phase 1: power-of-2 up to _PHASE1_END
    m = max(min_num_tokens, 1)
    while m <= min(max_num_tokens, _PHASE1_END):
        buckets.append(m)
        m *= 2

    # Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
    m = _PHASE1_END + _PHASE2_STEP
    while m <= min(max_num_tokens, _PHASE2_END):
        buckets.append(m)
        m += _PHASE2_STEP

    # Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
    m = _PHASE2_END + _PHASE3_STEP
    while m <= min(max_num_tokens, _PHASE3_END):
        buckets.append(m)
        m += _PHASE3_STEP

    # Phase 4: power-of-2 beyond _PHASE3_END
    m = _PHASE3_END * 2
    while m <= max_num_tokens:
        buckets.append(m)
        m *= 2

    if not buckets or buckets[-1] != max_num_tokens:
        buckets.append(max_num_tokens)

    return tuple(sorted(set(buckets)))

⚠️ Potential issue | 🟠 Major

Honor min_num_tokens across all phases.

Line 233 always starts phase 2 at 512, so get_hybrid_num_tokens_buckets(4096, min_num_tokens=1024) still emits 512 and 768. Line 227 also starts phase 1 from arbitrary non-powers like 129, breaking the documented power-of-2 phase.

Proposed fix
     buckets: List[int] = []
+    min_num_tokens = max(min_num_tokens, 1)
+    if max_num_tokens < min_num_tokens:
+        raise ValueError("max_num_tokens must be >= min_num_tokens")
 
     # Phase 1: power-of-2 up to _PHASE1_END
-    m = max(min_num_tokens, 1)
+    m = next_positive_power_of_2(min_num_tokens)
     while m <= min(max_num_tokens, _PHASE1_END):
         buckets.append(m)
         m *= 2
 
     # Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
-    m = _PHASE1_END + _PHASE2_STEP
+    m = _ceil_to_step(max(min_num_tokens, _PHASE1_END + 1), _PHASE2_STEP)
     while m <= min(max_num_tokens, _PHASE2_END):
         buckets.append(m)
         m += _PHASE2_STEP
 
     # Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
-    m = _PHASE2_END + _PHASE3_STEP
+    m = _ceil_to_step(max(min_num_tokens, _PHASE2_END + 1), _PHASE3_STEP)
     while m <= min(max_num_tokens, _PHASE3_END):
         buckets.append(m)
         m += _PHASE3_STEP
 
     # Phase 4: power-of-2 beyond _PHASE3_END
-    m = _PHASE3_END * 2
+    m = next_positive_power_of_2(max(min_num_tokens, _PHASE3_END + 1))
     while m <= max_num_tokens:
         buckets.append(m)
         m *= 2

@samuellees (Collaborator)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !570 has been created, and the CI pipeline #48974308 is currently running. I'll report back once the pipeline job completes.

@samuellees samuellees self-assigned this Apr 20, 2026

@samuellees (Collaborator) left a comment

LGTM, waiting for the CI pass

@samuellees samuellees added run-ci and removed run-ci labels Apr 21, 2026
@samuellees (Collaborator)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !570 has been updated with latest changes, and the CI pipeline #49091125 is currently running. I'll report back once the pipeline job completes.

@samuellees samuellees added run-ci and removed run-ci labels Apr 21, 2026
@samuellees samuellees enabled auto-merge (squash) April 22, 2026 01:57
@samuellees (Collaborator)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !570 has been created, and the CI pipeline #49156002 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #49156002: 1/20 passed

@samuellees (Collaborator) commented Apr 22, 2026

Hi @StudyingShao, could you please:

  1. Fix the conflicts
  2. Take a look at whether this CI failure is related to this PR: https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/jobs/302286472#L2688

Thx!

