fix(autotuner): hybrid bucket spacing and cache-key fix for fused MoE#3063
StudyingShao wants to merge 3 commits into flashinfer-ai:main from
Conversation
…ed MoE Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
📝 Walkthrough
The PR replaces a power-of-2 token-bucketing scheme with a new hybrid bucketing approach across multiple modules. New bucket generation and mapping functions are added.

Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request replaces the power-of-2 token bucketing strategy with a hybrid approach across MoE and GEMM modules. The new implementation introduces adaptive spacing using four phases—combining power-of-2 and linear steps—to provide finer granularity for MoE workloads and prevent performance degradation caused by large gaps between buckets. Additionally, the project version has been bumped to 0.6.6. I have no feedback to provide as there were no review comments to evaluate.
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/fused_moe/utils.py (1)
219-223: Replace `×` in docstring to satisfy Ruff RUF002.
Use plain `x` in the phase descriptions to avoid ambiguous Unicode lint warnings.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/utils.py` around lines 219 - 223, The docstring in flashinfer.fused_moe.utils that lists the Phase descriptions uses the Unicode multiplication sign '×', which triggers Ruff RUF002; replace each '×' with a plain ASCII 'x' in that docstring (the lines "Phase 1... (step ×2)" and "Phase 4... (step ×2)") so the phase descriptions read "step x2" to satisfy the linter.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/utils.py`:
- Around line 227-248: The bucket-generation logic in the function ignores
min_num_tokens for phases 2–4 so buckets smaller than min_num_tokens can be
emitted; update the phase start values to respect min_num_tokens by initializing
each phase's m to the maximum of min_num_tokens and the existing phase-start
value (e.g., for Phase 2 set m = max(min_num_tokens, _PHASE1_END +
_PHASE2_STEP), for Phase 3 set m = max(min_num_tokens, _PHASE2_END +
_PHASE3_STEP), and for Phase 4 set m = max(min_num_tokens, _PHASE3_END * 2)) and
keep the existing upper-bound checks (min(max_num_tokens, _PHASEx_END)) and
increments so no bucket below min_num_tokens is appended; apply this change
around the loops that build buckets (the blocks using variables m, _PHASE1_END,
_PHASE2_STEP, _PHASE2_END, _PHASE3_STEP, and _PHASE3_END).
📒 Files selected for processing (6)
- flashinfer/fused_moe/core.py
- flashinfer/fused_moe/cute_dsl/tuner.py
- flashinfer/fused_moe/utils.py
- flashinfer/gemm/gemm_base.py
- flashinfer/trtllm_low_latency_gemm.py
- version.txt
```python
# Phase 1: power-of-2 up to _PHASE1_END
m = max(min_num_tokens, 1)
while m <= min(max_num_tokens, _PHASE1_END):
    buckets.append(m)
    m *= 2

# Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
m = _PHASE1_END + _PHASE2_STEP
while m <= min(max_num_tokens, _PHASE2_END):
    buckets.append(m)
    m += _PHASE2_STEP

# Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
m = _PHASE2_END + _PHASE3_STEP
while m <= min(max_num_tokens, _PHASE3_END):
    buckets.append(m)
    m += _PHASE3_STEP

# Phase 4: power-of-2 beyond _PHASE3_END
m = _PHASE3_END * 2
while m <= max_num_tokens:
    buckets.append(m)
    m *= 2
```
min_num_tokens is not honored after phase 1.
For min_num_tokens > 256, phases 2–4 can still emit buckets smaller than min_num_tokens (e.g., 512), which violates the function contract and can under-bucket profiles.
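To make the finding concrete, here is a minimal reproduction sketch. The function body mirrors the snippet above, with the phase constants taken from the PR description (power-of-2 up to 256, step 256 up to 2048, step 512 up to 4096); the name `buggy_buckets` is illustrative, not the library's actual API.

```python
from typing import List, Tuple

# Phase constants as described in the PR (assumed values for illustration).
_PHASE1_END, _PHASE2_STEP, _PHASE2_END = 256, 256, 2048
_PHASE3_STEP, _PHASE3_END = 512, 4096

def buggy_buckets(max_num_tokens: int, min_num_tokens: int = 1) -> Tuple[int, ...]:
    buckets: List[int] = []
    m = max(min_num_tokens, 1)          # phase 1 honors min_num_tokens...
    while m <= min(max_num_tokens, _PHASE1_END):
        buckets.append(m)
        m *= 2
    m = _PHASE1_END + _PHASE2_STEP      # ...but phases 2-4 start at fixed values
    while m <= min(max_num_tokens, _PHASE2_END):
        buckets.append(m)
        m += _PHASE2_STEP
    m = _PHASE2_END + _PHASE3_STEP
    while m <= min(max_num_tokens, _PHASE3_END):
        buckets.append(m)
        m += _PHASE3_STEP
    m = _PHASE3_END * 2
    while m <= max_num_tokens:
        buckets.append(m)
        m *= 2
    return tuple(buckets)

# With min_num_tokens=600, phase 1 is skipped but phase 2 still starts at 512,
# so a bucket below the requested minimum is emitted:
print(min(buggy_buckets(8192, min_num_tokens=600)))  # 512
```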
Suggested fix
```diff
 def get_hybrid_num_tokens_buckets(
     max_num_tokens: int, min_num_tokens: int = 1
 ) -> Tuple[int, ...]:
+    if max_num_tokens < 1:
+        raise ValueError("max_num_tokens must be >= 1")
+
     buckets: List[int] = []
     # Phase 1: power-of-2 up to _PHASE1_END
-    m = max(min_num_tokens, 1)
+    min_num_tokens = max(min_num_tokens, 1)
+    m = next_positive_power_of_2(min_num_tokens)
     while m <= min(max_num_tokens, _PHASE1_END):
         buckets.append(m)
         m *= 2
     # Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
-    m = _PHASE1_END + _PHASE2_STEP
+    m = _ceil_to_step(max(min_num_tokens, _PHASE1_END + 1), _PHASE2_STEP)
     while m <= min(max_num_tokens, _PHASE2_END):
         buckets.append(m)
         m += _PHASE2_STEP
     # Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
-    m = _PHASE2_END + _PHASE3_STEP
+    m = _ceil_to_step(max(min_num_tokens, _PHASE2_END + 1), _PHASE3_STEP)
     while m <= min(max_num_tokens, _PHASE3_END):
         buckets.append(m)
         m += _PHASE3_STEP
     # Phase 4: power-of-2 beyond _PHASE3_END
-    m = _PHASE3_END * 2
+    m = next_positive_power_of_2(max(min_num_tokens, _PHASE3_END + 1))
     while m <= max_num_tokens:
         buckets.append(m)
         m *= 2
```
…file Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/autotuner.py`:
- Around line 783-787: The change reintroduces linked-dimension propagation by
iterating with zip(spec.input_idx, spec.dim_idx, strict=True) and assigning
mapped_value to every linked slot (using mapped_value and base_profile), which
reopens a known TRTLLM fused MoE crash; revert to assigning only the first
linked dimension: remove the for-loop that writes mapped_value to all indices
and instead set only base_profile[spec.input_idx[0]][spec.dim_idx[0]] to
spec.map_to_tuning_buckets(base_profile[spec.input_idx[0]][spec.dim_idx[0]]),
keeping the call to spec.map_to_tuning_buckets but preventing propagation across
spec.input_idx/spec.dim_idx.
```python
mapped_value = spec.map_to_tuning_buckets(
    base_profile[spec.input_idx[0]][spec.dim_idx[0]]
)
for i_idx, d_idx in zip(spec.input_idx, spec.dim_idx, strict=True):
    base_profile[i_idx][d_idx] = mapped_value
```
🧩 Analysis chain
🏁 Script executed:
```shell
#!/bin/bash
# Verify whether the C++ fix and test re-enablement for linked-dimension mapping landed in this branch.
set -euo pipefail
echo "== _find_nearest_profile mapping behavior =="
rg -n -C3 '_find_nearest_profile|zip\(spec\.input_idx, spec\.dim_idx|map_to_tuning_buckets' flashinfer/autotuner.py
echo
echo "== Known skipped tests status =="
rg -n -C2 'pytest\.mark\.skip|_find_nearest_profile|linked dimension|DynamicTensorSpec' tests/autotuner/test_autotuner_core.py || true
echo
echo "== TRTLLM launcher key lookup paths =="
rg -n -C3 'launchers_map\.at\(tile_N\)|tile_N|launcher' csrc/trtllm_fused_moe_kernel_launcher.cu || true
```

Repository: flashinfer-ai/flashinfer
Length of output: 22630
Reintroduces a known linked-dimension crash path that was reverted in PR #2697.
Lines 783–787 propagate the mapped value across all linked dimensions via zip(spec.input_idx, spec.dim_idx). This re-enables behavior that was intentionally reverted because it triggers TRTLLM fused MoE C++ runtime crashes (launchers_map.at(tile_N) missing key). The three regression tests that verify this mapping (test_find_nearest_profile_moe_shared_num_tokens_axis, test_find_nearest_profile_moe_same_bucket_same_profile, test_find_nearest_profile_maps_all_linked_dims) remain skipped with messages stating the propagation was reverted. Unless the corresponding C++ fix is included in this commit, revert to first-dimension-only mapping:
```python
base_profile[spec.input_idx[0]][spec.dim_idx[0]] = spec.map_to_tuning_buckets(
    base_profile[spec.input_idx[0]][spec.dim_idx[0]]
)
```
Closed. Moved to #3115.
📌 Description
Power-of-2 tuning buckets create exponentially growing gaps at large `num_tokens` values (e.g. a gap of 1024 between buckets 1024 and 2048). For MoE workloads, `avg_tokens_per_expert` can cross multiple tile selection boundaries within a single gap, causing non-power-of-2 inputs (like 1536) to be mapped to a much smaller bucket and receive a kernel optimized for the wrong workload size, sometimes running slower than larger batch sizes.
This PR replaces pure power-of-2 buckets with a four-phase hybrid spacing scheme: power-of-2 up to 256, linear step 256
up to 2048, linear step 512 up to 4096, then power-of-2 again. This keeps bucket count manageable (21 vs 14 for
max=8192) while ensuring adjacent buckets never span a tile selection boundary.
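The four-phase spacing described above can be sketched as follows; the function name and hard-coded phase boundaries are illustrative stand-ins, not FlashInfer's actual API.

```python
# Illustrative sketch of the four-phase hybrid bucket spacing:
# power-of-2 up to 256, step 256 up to 2048, step 512 up to 4096,
# then power-of-2 again.
def hybrid_buckets(max_num_tokens: int) -> list:
    buckets = []
    m = 1
    while m <= min(max_num_tokens, 256):   # Phase 1: power-of-2 up to 256
        buckets.append(m)
        m *= 2
    m = 512
    while m <= min(max_num_tokens, 2048):  # Phase 2: linear step 256
        buckets.append(m)
        m += 256
    m = 2560
    while m <= min(max_num_tokens, 4096):  # Phase 3: linear step 512
        buckets.append(m)
        m += 512
    m = 8192
    while m <= max_num_tokens:             # Phase 4: power-of-2 again
        buckets.append(m)
        m *= 2
    return buckets

buckets = hybrid_buckets(8192)
print(len(buckets))   # 21 buckets, vs 14 for pure power-of-2 up to 8192
print(1536 in buckets)  # True: 1536 now gets its own bucket
```

For max=8192 this yields the 21 buckets cited in the description, and inputs like 1536 land exactly on a bucket instead of being rounded down to 1024.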
This PR also fixes a bug in `AutoTuner._find_nearest_profile` where only the first tensor in a multi-tensor `DynamicTensorSpec` had its dimension mapped to the nearest bucket. For ops like `trtllm_fp8_per_tensor_scale_moe` that declare 5 input tensors sharing the same `num_tokens` dimension, the remaining 4 tensors retained their original shape, producing a different cache key at inference time than what was stored during profiling. This caused a silent cache miss and fallback to the default tactic for any non-bucket-boundary token count (e.g. bs=384 mapped to bucket 512 but always fell back to TileN=8).
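A minimal sketch of the cache-key mismatch, under the simplifying assumptions that a cache key is just a tuple of tensor shapes and that bucket mapping rounds `num_tokens` up to the nearest bucket (`map_to_bucket`, `BUCKETS`, and the shapes are illustrative, not the library's actual API):

```python
import bisect

BUCKETS = [256, 512, 768, 1024]  # illustrative tuning buckets

def map_to_bucket(num_tokens: int) -> int:
    # Round up to the nearest bucket (stand-in for the spec's
    # map_to_tuning_buckets callback).
    return BUCKETS[bisect.bisect_left(BUCKETS, num_tokens)]

# Five input tensors sharing the same num_tokens dimension (dim 0), bs=384.
shapes = [(384, 4096)] * 5
linked_dims = [(i, 0) for i in range(5)]  # (tensor_idx, dim_idx) pairs

# Buggy behavior: only the first linked tensor is bucketed.
buggy = [list(s) for s in shapes]
buggy[0][0] = map_to_bucket(buggy[0][0])
buggy_key = tuple(tuple(s) for s in buggy)

# Fixed behavior: the mapped value is propagated to every linked tensor.
fixed = [list(s) for s in shapes]
mapped = map_to_bucket(fixed[0][0])
for t_idx, d_idx in linked_dims:
    fixed[t_idx][d_idx] = mapped
fixed_key = tuple(tuple(s) for s in fixed)

# Profiling stored results under the fully-bucketed key, so the buggy key
# (mixing 512 and 384) never matches: a silent cache miss.
print(buggy_key[0], buggy_key[1])  # (512, 4096) (384, 4096)
print(fixed_key[0], fixed_key[1])  # (512, 4096) (512, 4096)
```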
🔍 Related Issues
N/A
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are
complete.
✅ Pre-commit Checks
- Installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Added or updated tests as needed (unittest, etc.).

Reviewer Notes
- In `trtllm-gen` fused MoE FP8-Per-Tensor (num_experts=24, top_k=8): `num_tokens=1536` was mapped to bucket 1024, selecting a small-tile kernel (t128x32x256) instead of the large-tile kernel (t128x128x128) that 2048 gets, making 1536 slower than 2048.
- bs=384 cache-missed because `_find_nearest_profile` only bucketed the output buffer tensor, leaving the other 4 input tensors at their original 384-token shapes. After fixing both, bs=384 goes from 1.00x to 2.40x speedup (FP8-Per-Tensor, num_experts=192, top_k=8, hidden=4096, intermediate=1536).
- Only the trtllm-gen path has been benchmarked.
- The cache-key fix applies whenever a `DynamicTensorSpec` lists more than one input tensor sharing a dynamic dimension.
- If kernels introduce larger tiles, these phase boundaries may need adjustment.
- `get_power_of_2_num_tokens_buckets` and `get_last_power_of_2_num_tokens_buckets` are removed; `last_positive_power_of_2` is retained as it's still used by autotuner unit tests.