
fix(autotuner): hybrid bucket spacing and cache-key fix for fused MoE #3063

Closed
StudyingShao wants to merge 3 commits into flashinfer-ai:main from StudyingShao:jiangs/autotuner-hybrid-bucket-spacing

Conversation


StudyingShao (Contributor) commented Apr 14, 2026

📌 Description

Power-of-2 tuning buckets create exponentially growing gaps at large num_tokens values (e.g. a gap of 1024 between the
1024 and 2048 buckets). For MoE workloads, avg_tokens_per_expert can cross multiple tile selection boundaries within a
single gap, causing non-power-of-2 inputs (like 1536) to be mapped to a much smaller bucket and receive a kernel
optimized for the wrong workload size — sometimes running slower than larger batch sizes.

This PR replaces pure power-of-2 buckets with a four-phase hybrid spacing scheme: power-of-2 up to 256, linear step 256
up to 2048, linear step 512 up to 4096, then power-of-2 again. This keeps bucket count manageable (21 vs 14 for
max=8192) while ensuring adjacent buckets never span a tile selection boundary.
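
To make the scheme concrete, here is a minimal standalone sketch of the four-phase bucket generation described above (the real implementation is get_hybrid_num_tokens_buckets in flashinfer/fused_moe/utils.py; its exact constants, signature, and edge-case handling may differ):

    def hybrid_num_tokens_buckets(max_num_tokens: int) -> list[int]:
        buckets, m = [], 1
        # Phase 1: power-of-2 up to 256.
        while m <= min(max_num_tokens, 256):
            buckets.append(m)
            m *= 2
        # Phase 2: linear step 256 up to 2048.
        m = 256 + 256
        while m <= min(max_num_tokens, 2048):
            buckets.append(m)
            m += 256
        # Phase 3: linear step 512 up to 4096.
        m = 2048 + 512
        while m <= min(max_num_tokens, 4096):
            buckets.append(m)
            m += 512
        # Phase 4: power-of-2 beyond 4096.
        m = 4096 * 2
        while m <= max_num_tokens:
            buckets.append(m)
            m *= 2
        return buckets

    # 21 buckets for max=8192, versus 14 with pure power-of-2 spacing.
    assert len(hybrid_num_tokens_buckets(8192)) == 21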

This PR also fixes a bug in AutoTuner._find_nearest_profile where only the first tensor in a multi-tensor
DynamicTensorSpec had its dimension mapped to the nearest bucket. For ops like trtllm_fp8_per_tensor_scale_moe that
declare 5 input tensors sharing the same num_tokens dimension, the remaining 4 tensors retained their original shape —
producing a different cache key at inference time than what was stored during profiling. This caused a silent cache
miss and fallback to the default tactic for any non-bucket-boundary token count (e.g. bs=384 mapped to bucket 512 but
always fell back to TileN=8).
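
The shape of the fix, excerpted from the autotuner diff quoted in the review thread below (this is only the mapped-value propagation inside AutoTuner._find_nearest_profile, not a complete function):

    # Before: only the first linked tensor's dynamic dimension was bucketed.
    # base_profile[spec.input_idx[0]][spec.dim_idx[0]] = spec.map_to_tuning_buckets(
    #     base_profile[spec.input_idx[0]][spec.dim_idx[0]]
    # )

    # After: the bucketed value is written to every (input, dim) pair that shares
    # the dynamic dimension, so the inference-time cache key matches the key that
    # was stored during profiling.
    mapped_value = spec.map_to_tuning_buckets(
        base_profile[spec.input_idx[0]][spec.dim_idx[0]]
    )
    for i_idx, d_idx in zip(spec.input_idx, spec.dim_idx, strict=True):
        base_profile[i_idx][d_idx] = mapped_value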

🔍 Related Issues

N/A

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are
complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

  • The bucket spacing issue was observed with trtllm-gen fused MoE FP8-Per-Tensor (num_experts=24, top_k=8):
    num_tokens=1536 was mapped to bucket 1024, selecting a small-tile kernel (t128x32x256) instead of the large-tile
kernel (t128x128x128) that 2048 gets — making 1536 slower than 2048 (see the sketch after these notes).
  • The cache-miss bug compounded this: even for the hybrid-bucket case, bs=384 mapped correctly to bucket 512 but always
    cache-missed because _find_nearest_profile only bucketed the output buffer tensor, leaving the other 4 input tensors
    at their original 384-token shapes. After fixing both, bs=384 goes from 1.00x to 2.40x speedup (Fp8-Per-Tensor,
    num_experts=192, top_k=8, hidden=4096, intermediate=1536).
  • The hybrid bucket fix applies to CUTLASS MoE, CuTE DSL MoE, GEMM, and low-latency GEMM paths for consistency, though
    only the trtllm-gen path has been benchmarked.
  • The cache-miss fix applies to any op whose DynamicTensorSpec lists more than one input tensor sharing a dynamic
    dimension.
  • Phase thresholds (256/2048/4096) and step sizes (256/512) are tuned for current max tile sizes (128/256). If future
    kernels introduce larger tiles, these may need adjustment.
  • Old functions get_power_of_2_num_tokens_buckets and get_last_power_of_2_num_tokens_buckets are removed;
    last_positive_power_of_2 is retained as it's still used by autotuner unit tests.
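
A minimal illustration of the first two notes (hypothetical standalone code; the helper name mirrors the retained last_positive_power_of_2 in flashinfer/fused_moe/utils.py, but this is not the library implementation):

    def last_positive_power_of_2(x: int) -> int:
        # Old bucketing: round num_tokens down to the previous power of two.
        p = 1
        while p * 2 <= x:
            p *= 2
        return p

    # Old scheme: num_tokens=1536 lands in the 1024 bucket and inherits the
    # small-tile kernel profiled for 1024 tokens.
    assert last_positive_power_of_2(1536) == 1024

    # Hybrid scheme: phase 2 (linear step 256) places a bucket exactly at 1536,
    # so 1536 is profiled and served with a kernel tuned for its own size.
    phase2_buckets = list(range(512, 2048 + 1, 256))
    assert 1536 in phase2_buckets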

aleozlx and others added 2 commits March 10, 2026 15:07
…ed MoE

Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>

coderabbitai Bot commented Apr 14, 2026

📝 Walkthrough

The PR replaces a power-of-2 token-bucketing scheme with a new hybrid bucketing approach across multiple modules. New bucket generation and mapping functions are added in flashinfer/fused_moe/utils.py, and MoE, GEMM, and TRTLLM integrations are updated to use the hybrid scheme. Version bumped to 0.6.6.

Changes

Cohort / File(s) and Summary

  • Token Bucketing Scheme (flashinfer/fused_moe/utils.py): Added hybrid bucketing: _PHASE* constants, _ceil_to_step, get_hybrid_num_tokens_buckets(max_num_tokens, min_num_tokens=1), map_to_hybrid_bucket(x, max_num_tokens), map_to_hybrid_bucket_uncapped(x); removed older power-of-2 bucket helpers.
  • MoE Core & Tuner (flashinfer/fused_moe/core.py, flashinfer/fused_moe/cute_dsl/tuner.py): Replaced get_last_power_of_2_num_tokens_buckets/last_positive_power_of_2 usage with get_hybrid_num_tokens_buckets and map_to_hybrid_bucket in DynamicTensorSpec and tuning-config-related mappings.
  • GEMM & TRTLLM (flashinfer/gemm/gemm_base.py, flashinfer/trtllm_low_latency_gemm.py): Replaced the power-of-2 bucket generator and mapping with get_hybrid_num_tokens_buckets and map_to_hybrid_bucket_uncapped for DynamicTensorSpec instances used in tuning.
  • Autotuner behavior (flashinfer/autotuner.py): Added conditional debug tracing via env FLASHINFER_AUTOTUNER_TRACE=1; changed _find_nearest_profile to map a spec once and assign the mapped value to all (input, dim) pairs.
  • Version (version.txt): Bumped version from 0.6.5 to 0.6.6.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • jiahanc
  • nv-yunzheq
  • IwakuraRein
  • samuellees

Poem

🐰 I hopped through buckets, old and new,
Phase by phase I lined the queue,
Power and linear steps in tune,
Tokens sorted under moon,
Hooray — the hybrid finds its view! 🎉

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 50.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Description check: ✅ Passed. The pull request description is comprehensive and well-structured, following the template with all required sections completed.
  • Title check: ✅ Passed. The title accurately reflects the main changes: replacing power-of-2 autotuner buckets with hybrid bucket spacing and fixing cache-key/profile-finding logic for fused MoE workloads.



gemini-code-assist (bot) left a comment


Code Review

This pull request replaces the power-of-2 token bucketing strategy with a hybrid approach across MoE and GEMM modules. The new implementation introduces adaptive spacing using four phases—combining power-of-2 and linear steps—to provide finer granularity for MoE workloads and prevent performance degradation caused by large gaps between buckets. Additionally, the project version has been bumped to 0.6.6. I have no feedback to provide as there were no review comments to evaluate.


coderabbitai (bot) left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
flashinfer/fused_moe/utils.py (1)

219-223: Replace × in docstring to satisfy Ruff RUF002.
Use plain x in the phase descriptions to avoid ambiguous Unicode lint warnings.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/utils.py` around lines 219 - 223, The docstring in
flashinfer.fused_moe.utils that lists the Phase descriptions uses the Unicode
multiplication sign '×', which triggers Ruff RUF002; replace each '×' with a
plain ASCII 'x' in that docstring (the lines "Phase 1... (step ×2)" and "Phase
4... (step ×2)") so the phase descriptions read "step x2" to satisfy the linter.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8f753193-443e-4643-a92a-b8a1fe2527c7

📥 Commits

Reviewing files that changed from the base of the PR and between 00ac505 and d5f4e8d.

📒 Files selected for processing (6)
  • flashinfer/fused_moe/core.py
  • flashinfer/fused_moe/cute_dsl/tuner.py
  • flashinfer/fused_moe/utils.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/trtllm_low_latency_gemm.py
  • version.txt

Comment on lines +227 to +248
    m = max(min_num_tokens, 1)
    while m <= min(max_num_tokens, _PHASE1_END):
        buckets.append(m)
        m *= 2

    # Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
    m = _PHASE1_END + _PHASE2_STEP
    while m <= min(max_num_tokens, _PHASE2_END):
        buckets.append(m)
        m += _PHASE2_STEP

    # Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
    m = _PHASE2_END + _PHASE3_STEP
    while m <= min(max_num_tokens, _PHASE3_END):
        buckets.append(m)
        m += _PHASE3_STEP

    # Phase 4: power-of-2 beyond _PHASE3_END
    m = _PHASE3_END * 2
    while m <= max_num_tokens:
        buckets.append(m)
        m *= 2


⚠️ Potential issue | 🟡 Minor

min_num_tokens is not honored after phase 1.
For min_num_tokens > 256, phases 2–4 can still emit buckets smaller than min_num_tokens (e.g., 512), which violates the function contract and can under-bucket profiles.

Suggested fix
 def get_hybrid_num_tokens_buckets(
     max_num_tokens: int, min_num_tokens: int = 1
 ) -> Tuple[int, ...]:
+    if max_num_tokens < 1:
+        raise ValueError("max_num_tokens must be >= 1")
+
     buckets: List[int] = []
 
     # Phase 1: power-of-2 up to _PHASE1_END
-    m = max(min_num_tokens, 1)
+    min_num_tokens = max(min_num_tokens, 1)
+    m = next_positive_power_of_2(min_num_tokens)
     while m <= min(max_num_tokens, _PHASE1_END):
         buckets.append(m)
         m *= 2
 
     # Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
-    m = _PHASE1_END + _PHASE2_STEP
+    m = _ceil_to_step(max(min_num_tokens, _PHASE1_END + 1), _PHASE2_STEP)
     while m <= min(max_num_tokens, _PHASE2_END):
         buckets.append(m)
         m += _PHASE2_STEP
 
     # Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
-    m = _PHASE2_END + _PHASE3_STEP
+    m = _ceil_to_step(max(min_num_tokens, _PHASE2_END + 1), _PHASE3_STEP)
     while m <= min(max_num_tokens, _PHASE3_END):
         buckets.append(m)
         m += _PHASE3_STEP
 
     # Phase 4: power-of-2 beyond _PHASE3_END
-    m = _PHASE3_END * 2
+    m = next_positive_power_of_2(max(min_num_tokens, _PHASE3_END + 1))
     while m <= max_num_tokens:
         buckets.append(m)
         m *= 2
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/utils.py` around lines 227 - 248, The bucket-generation
logic in the function ignores min_num_tokens for phases 2–4 so buckets smaller
than min_num_tokens can be emitted; update the phase start values to respect
min_num_tokens by initializing each phase's m to the maximum of min_num_tokens
and the existing phase-start value (e.g., for Phase 2 set m =
max(min_num_tokens, _PHASE1_END + _PHASE2_STEP), for Phase 3 set m =
max(min_num_tokens, _PHASE2_END + _PHASE3_STEP), and for Phase 4 set m =
max(min_num_tokens, _PHASE3_END * 2)) and keep the existing upper-bound checks
(min(max_num_tokens, _PHASEx_END)) and increments so no bucket below
min_num_tokens is appended; apply this change around the loops that build
buckets (the blocks using variables m, _PHASE1_END, _PHASE2_STEP, _PHASE2_END,
_PHASE3_STEP, and _PHASE3_END).

…file

Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>

coderabbitai (bot) left a comment


Actionable comments posted: 1

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7f24f724-8a97-4155-a787-8e14a6d594e2

📥 Commits

Reviewing files that changed from the base of the PR and between d5f4e8d and 37b425c.

📒 Files selected for processing (1)
  • flashinfer/autotuner.py

Comment thread flashinfer/autotuner.py
Comment on lines +783 to +787
mapped_value = spec.map_to_tuning_buckets(
    base_profile[spec.input_idx[0]][spec.dim_idx[0]]
)
for i_idx, d_idx in zip(spec.input_idx, spec.dim_idx, strict=True):
    base_profile[i_idx][d_idx] = mapped_value


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether the C++ fix and test re-enablement for linked-dimension mapping landed in this branch.

set -euo pipefail

echo "== _find_nearest_profile mapping behavior =="
rg -n -C3 '_find_nearest_profile|zip\(spec\.input_idx, spec\.dim_idx|map_to_tuning_buckets' flashinfer/autotuner.py

echo
echo "== Known skipped tests status =="
rg -n -C2 'pytest\.mark\.skip|_find_nearest_profile|linked dimension|DynamicTensorSpec' tests/autotuner/test_autotuner_core.py || true

echo
echo "== TRTLLM launcher key lookup paths =="
rg -n -C3 'launchers_map\.at\(tile_N\)|tile_N|launcher' csrc/trtllm_fused_moe_kernel_launcher.cu || true

Repository: flashinfer-ai/flashinfer

Length of output: 22630


Reintroduces a known linked-dimension crash path that was reverted in PR #2697.

Lines 783–787 propagate the mapped value across all linked dimensions via zip(spec.input_idx, spec.dim_idx). This re-enables behavior that was intentionally reverted because it triggers TRTLLM fused MoE C++ runtime crashes (launchers_map.at(tile_N) missing key). The three regression tests that verify this mapping (test_find_nearest_profile_moe_shared_num_tokens_axis, test_find_nearest_profile_moe_same_bucket_same_profile, test_find_nearest_profile_maps_all_linked_dims) remain skipped with messages stating the propagation was reverted. Unless the corresponding C++ fix is included in this commit, revert to first-dimension-only mapping:

base_profile[spec.input_idx[0]][spec.dim_idx[0]] = spec.map_to_tuning_buckets(
    base_profile[spec.input_idx[0]][spec.dim_idx[0]]
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/autotuner.py` around lines 783 - 787, The change reintroduces
linked-dimension propagation by iterating with zip(spec.input_idx, spec.dim_idx,
strict=True) and assigning mapped_value to every linked slot (using mapped_value
and base_profile), which reopens a known TRTLLM fused MoE crash; revert to
assigning only the first linked dimension: remove the for-loop that writes
mapped_value to all indices and instead set only
base_profile[spec.input_idx[0]][spec.dim_idx[0]] to
spec.map_to_tuning_buckets(base_profile[spec.input_idx[0]][spec.dim_idx[0]]),
keeping the call to spec.map_to_tuning_buckets but preventing propagation across
spec.input_idx/spec.dim_idx.

StudyingShao changed the title from "fix: Replace power-of-2 autotuner buckets with hybrid spacing for fused MoE" to "fix(autotuner): hybrid bucket spacing and cache-key fix for fused MoE" on Apr 16, 2026
@StudyingShao

Closed. Moved to #3115.

