18 changes: 14 additions & 4 deletions flashinfer/autotuner.py
@@ -393,6 +393,16 @@ def search_cache(
        cache_key = AutoTuner._get_cache_key(
            custom_op, r, input_shapes, tuning_config
        )
+        if os.environ.get("FLASHINFER_AUTOTUNER_TRACE", "0") == "1":
+            logger.debug(
+                f"[AutoTuner] op={custom_op} "
+                f"runner_hash={hash(r)} "
+                f"input_shapes={input_shapes} "
+                f"nearest_profile={cache_key[-1]} "
+                f"cache_key={cache_key} "
+                f"hit={cache_key in self.profiling_cache} "
+                f"cache_size={len(self.profiling_cache)}"
+            )
        if (
            os.environ.get("FLASHINFER_AUTOTUNER_LOAD_FROM_FILE", "0") == "1"
            and not self.is_tuning_mode
@@ -770,11 +780,11 @@ def _find_nearest_profile(
        base_profile = list(list(shape) for shape in shapes)

        for spec in tuning_config.dynamic_tensor_specs:
-            base_profile[spec.input_idx[0]][spec.dim_idx[0]] = (
-                spec.map_to_tuning_buckets(
-                    base_profile[spec.input_idx[0]][spec.dim_idx[0]]
-                )
-            )
+            mapped_value = spec.map_to_tuning_buckets(
+                base_profile[spec.input_idx[0]][spec.dim_idx[0]]
+            )
+            for i_idx, d_idx in zip(spec.input_idx, spec.dim_idx, strict=True):
+                base_profile[i_idx][d_idx] = mapped_value
Comment on lines +783 to +787
Contributor
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether the C++ fix and test re-enablement for linked-dimension mapping landed in this branch.

set -euo pipefail

echo "== _find_nearest_profile mapping behavior =="
rg -n -C3 '_find_nearest_profile|zip\(spec\.input_idx, spec\.dim_idx|map_to_tuning_buckets' flashinfer/autotuner.py

echo
echo "== Known skipped tests status =="
rg -n -C2 'pytest\.mark\.skip|_find_nearest_profile|linked dimension|DynamicTensorSpec' tests/autotuner/test_autotuner_core.py || true

echo
echo "== TRTLLM launcher key lookup paths =="
rg -n -C3 'launchers_map\.at\(tile_N\)|tile_N|launcher' csrc/trtllm_fused_moe_kernel_launcher.cu || true

Repository: flashinfer-ai/flashinfer

Length of output: 22630


Reintroduces a known linked-dimension crash path that was reverted in PR #2697.

Lines 783–787 propagate the mapped value across all linked dimensions via zip(spec.input_idx, spec.dim_idx). This re-enables behavior that was intentionally reverted because it triggers TRTLLM fused MoE C++ runtime crashes (launchers_map.at(tile_N) missing key). The three regression tests that verify this mapping (test_find_nearest_profile_moe_shared_num_tokens_axis, test_find_nearest_profile_moe_same_bucket_same_profile, test_find_nearest_profile_maps_all_linked_dims) remain skipped with messages stating the propagation was reverted. Unless the corresponding C++ fix is included in this commit, revert to first-dimension-only mapping:

base_profile[spec.input_idx[0]][spec.dim_idx[0]] = spec.map_to_tuning_buckets(
    base_profile[spec.input_idx[0]][spec.dim_idx[0]]
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/autotuner.py` around lines 783 - 787, The change reintroduces
linked-dimension propagation by iterating with zip(spec.input_idx, spec.dim_idx,
strict=True) and assigning mapped_value to every linked slot (using mapped_value
and base_profile), which reopens a known TRTLLM fused MoE crash; revert to
assigning only the first linked dimension: remove the for-loop that writes
mapped_value to all indices and instead set only
base_profile[spec.input_idx[0]][spec.dim_idx[0]] to
spec.map_to_tuning_buckets(base_profile[spec.input_idx[0]][spec.dim_idx[0]]),
keeping the call to spec.map_to_tuning_buckets but preventing propagation across
spec.input_idx/spec.dim_idx.
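For reference, a minimal self-contained sketch of the two behaviors under discussion (`SpecSketch` and both helper functions are illustrative stand-ins, not flashinfer APIs):

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple


@dataclass
class SpecSketch:
    """Illustrative stand-in for DynamicTensorSpec (not the flashinfer API)."""
    input_idx: Tuple[int, ...]
    dim_idx: Tuple[int, ...]
    map_to_tuning_buckets: Callable[[int], int]


def nearest_profile_first_dim_only(shapes: Sequence[Sequence[int]], spec: SpecSketch) -> List[List[int]]:
    # Behavior after the #2697 revert: only the first linked slot is bucketed.
    profile = [list(s) for s in shapes]
    profile[spec.input_idx[0]][spec.dim_idx[0]] = spec.map_to_tuning_buckets(
        profile[spec.input_idx[0]][spec.dim_idx[0]]
    )
    return profile


def nearest_profile_propagated(shapes: Sequence[Sequence[int]], spec: SpecSketch) -> List[List[int]]:
    # Behavior this diff reintroduces: the bucketed value is written to every
    # linked (input_idx, dim_idx) slot.
    profile = [list(s) for s in shapes]
    mapped = spec.map_to_tuning_buckets(profile[spec.input_idx[0]][spec.dim_idx[0]])
    for i_idx, d_idx in zip(spec.input_idx, spec.dim_idx):
        profile[i_idx][d_idx] = mapped
    return profile


spec = SpecSketch((0, 1), (0, 0), lambda x: 1 << (x - 1).bit_length())  # round up to power of 2
shapes = [(300, 64), (300, 128)]  # two tensors sharing a num_tokens axis
print(nearest_profile_first_dim_only(shapes, spec))  # [[512, 64], [300, 128]]
print(nearest_profile_propagated(shapes, spec))      # [[512, 64], [512, 128]]
```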


# associated dimensions dependent on other free dynamic dimensions, so assign -1 in the profile
for constraint_spec in tuning_config.constraint_specs:
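For completeness: the trace hook added in the first hunk above is gated on an environment variable. A minimal sketch of enabling it; the variable name comes from this diff, while the logging setup is ordinary Python and an assumption about how debug output is surfaced:

```python
import logging
import os

# Enable the per-lookup cache trace added in this PR (variable name per the diff).
os.environ["FLASHINFER_AUTOTUNER_TRACE"] = "1"

# logger.debug(...) output is only visible if the debug level is enabled.
logging.basicConfig(level=logging.DEBUG)
```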
28 changes: 14 additions & 14 deletions flashinfer/fused_moe/core.py
@@ -51,8 +51,8 @@
get_compute_capability,
)
from .utils import (
-    get_last_power_of_2_num_tokens_buckets,
-    last_positive_power_of_2,
+    get_hybrid_num_tokens_buckets,
+    map_to_hybrid_bucket,
)


@@ -363,8 +363,8 @@ class MoERunner(TunableRunner):
DynamicTensorSpec(
(0,),
(0,),
-            get_last_power_of_2_num_tokens_buckets(8192),
-            lambda x: min(last_positive_power_of_2(x), 8192),
+            get_hybrid_num_tokens_buckets(8192),
+            lambda x: map_to_hybrid_bucket(x, 8192),
),
)
)
@@ -497,8 +497,8 @@ def refine_tuning_config(cls, tune_max_num_tokens: int):
DynamicTensorSpec(
(0,),
(0,),
-            get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens),
-            lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
+            get_hybrid_num_tokens_buckets(tune_max_num_tokens),
+            lambda x: map_to_hybrid_bucket(x, tune_max_num_tokens),
),
)
)
@@ -979,8 +979,8 @@ class MoERunner(TunableRunner):
DynamicTensorSpec(
(0, 1, 2, 3, 4, 5),
(0, 0, 0, 0, 0, 0),
-            get_last_power_of_2_num_tokens_buckets(8192, 1),
-            lambda x: min(last_positive_power_of_2(x), 8192),
+            get_hybrid_num_tokens_buckets(8192, 1),
+            lambda x: map_to_hybrid_bucket(x, 8192),
dynamic_tensor_initializers,
),
)
@@ -990,8 +990,8 @@ class MoERunner(TunableRunner):
DynamicTensorSpec(
(0, 1, 2, 3, 4),
(0, 0, 0, 0, 0),
-            get_last_power_of_2_num_tokens_buckets(8192, 1),
-            lambda x: min(last_positive_power_of_2(x), 8192),
+            get_hybrid_num_tokens_buckets(8192, 1),
+            lambda x: map_to_hybrid_bucket(x, 8192),
dynamic_tensor_initializers[:5],
),
),
@@ -1291,8 +1291,8 @@ def refine_tuning_config(cls, tune_max_num_tokens: int, **kwargs):
DynamicTensorSpec(
(0, 1, 2, 3, 4, 5),
(0, 0, 0, 0, 0, 0),
-            get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),
-            lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
+            get_hybrid_num_tokens_buckets(tune_max_num_tokens, 1),
+            lambda x: map_to_hybrid_bucket(x, tune_max_num_tokens),
cls.dynamic_tensor_initializers,
),
),
@@ -1303,8 +1303,8 @@
DynamicTensorSpec(
(0, 1, 2, 3, 4),
(0, 0, 0, 0, 0),
-            get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),
-            lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
+            get_hybrid_num_tokens_buckets(tune_max_num_tokens, 1),
+            lambda x: map_to_hybrid_bucket(x, tune_max_num_tokens),
cls.dynamic_tensor_initializers[:5],
),
),
8 changes: 4 additions & 4 deletions flashinfer/fused_moe/cute_dsl/tuner.py
@@ -42,8 +42,8 @@
TuningConfig,
)
from ..utils import (
-    get_last_power_of_2_num_tokens_buckets,
-    last_positive_power_of_2,
+    get_hybrid_num_tokens_buckets,
+    map_to_hybrid_bucket,
)

logger = logging.getLogger(__name__)
@@ -274,8 +274,8 @@ class CuteDslFusedMoENvfp4Runner(TunableRunner):
DynamicTensorSpec(
input_idx=(0, 1, 2, 3, 11), # x, x_sf, experts, scales, moe_output
dim_idx=(0, 0, 0, 0, 0), # First dimension is num_tokens for all
-            gen_tuning_buckets=get_last_power_of_2_num_tokens_buckets(8192),
-            map_to_tuning_buckets=lambda x: min(last_positive_power_of_2(x), 8192),
+            gen_tuning_buckets=get_hybrid_num_tokens_buckets(8192),
+            map_to_tuning_buckets=lambda x: map_to_hybrid_bucket(x, 8192),
tensor_initializers=dynamic_tensor_initializers,
),
),
109 changes: 92 additions & 17 deletions flashinfer/fused_moe/utils.py
@@ -192,27 +192,102 @@ def nearest_in_buckets(x: int, buckets: List[int]) -> int:
return min(max(next_positive_power_of_2(x), buckets[0]), buckets[-1])


-def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]:
-    max_num_tokens = next_positive_power_of_2(max_num_tokens)
-    num_token_buckets = []
-    m = max_num_tokens
-    while m >= 1:
-        num_token_buckets.append(m)
-        m //= 2
+_PHASE1_END = 256
+_PHASE2_STEP = 256
+_PHASE2_END = 2048
+_PHASE3_STEP = 512
+_PHASE3_END = 4096

-    return tuple(num_token_buckets)

+def _ceil_to_step(x: int, step: int) -> int:
+    return ((x + step - 1) // step) * step

-def get_last_power_of_2_num_tokens_buckets(
-    max_num_tokens, min_num_tokens=1
+def get_hybrid_num_tokens_buckets(
+    max_num_tokens: int, min_num_tokens: int = 1
 ) -> Tuple[int, ...]:
-    max_num_tokens = last_positive_power_of_2(max_num_tokens)
-    num_token_buckets = []
-    m = max_num_tokens
-    while m >= min_num_tokens:
-        num_token_buckets.append(m)
-        m //= 2
-    return tuple(num_token_buckets)
"""Generate tuning buckets with adaptive spacing.

Pure power-of-2 spacing creates huge gaps at large values (e.g. 1024
between bucket 1024 and 2048). For MoE workloads the
avg_tokens_per_expert can jump across multiple tile boundaries inside a
single gap, forcing the autotuner to pick a kernel optimised for a very
different workload size.

This function uses four phases with progressively coarser spacing::

Phase 1: [min .. 256] β€” power-of-2 (step Γ—2)
Phase 2: (256 .. 2048] β€” linear step 256
Phase 3: (2048 .. 4096] β€” linear step 512
Phase 4: (4096 .. max] β€” power-of-2 (step Γ—2)
"""
buckets: List[int] = []

# Phase 1: power-of-2 up to _PHASE1_END
m = max(min_num_tokens, 1)
while m <= min(max_num_tokens, _PHASE1_END):
buckets.append(m)
m *= 2

# Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
m = _PHASE1_END + _PHASE2_STEP
while m <= min(max_num_tokens, _PHASE2_END):
buckets.append(m)
m += _PHASE2_STEP

# Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
m = _PHASE2_END + _PHASE3_STEP
while m <= min(max_num_tokens, _PHASE3_END):
buckets.append(m)
m += _PHASE3_STEP

# Phase 4: power-of-2 beyond _PHASE3_END
m = _PHASE3_END * 2
while m <= max_num_tokens:
buckets.append(m)
m *= 2
Comment on lines +227 to +248
Contributor
⚠️ Potential issue | 🟡 Minor

min_num_tokens is not honored after phase 1.
For min_num_tokens > 256, phases 2–4 can still emit buckets smaller than min_num_tokens (e.g., 512), which violates the function contract and can under-bucket profiles.

Suggested fix
 def get_hybrid_num_tokens_buckets(
     max_num_tokens: int, min_num_tokens: int = 1
 ) -> Tuple[int, ...]:
+    if max_num_tokens < 1:
+        raise ValueError("max_num_tokens must be >= 1")
+
     buckets: List[int] = []
 
     # Phase 1: power-of-2 up to _PHASE1_END
-    m = max(min_num_tokens, 1)
+    min_num_tokens = max(min_num_tokens, 1)
+    m = next_positive_power_of_2(min_num_tokens)
     while m <= min(max_num_tokens, _PHASE1_END):
         buckets.append(m)
         m *= 2
 
     # Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
-    m = _PHASE1_END + _PHASE2_STEP
+    m = _ceil_to_step(max(min_num_tokens, _PHASE1_END + 1), _PHASE2_STEP)
     while m <= min(max_num_tokens, _PHASE2_END):
         buckets.append(m)
         m += _PHASE2_STEP
 
     # Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
-    m = _PHASE2_END + _PHASE3_STEP
+    m = _ceil_to_step(max(min_num_tokens, _PHASE2_END + 1), _PHASE3_STEP)
     while m <= min(max_num_tokens, _PHASE3_END):
         buckets.append(m)
         m += _PHASE3_STEP
 
     # Phase 4: power-of-2 beyond _PHASE3_END
-    m = _PHASE3_END * 2
+    m = next_positive_power_of_2(max(min_num_tokens, _PHASE3_END + 1))
     while m <= max_num_tokens:
         buckets.append(m)
         m *= 2
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/utils.py` around lines 227 - 248, The bucket-generation
logic in the function ignores min_num_tokens for phases 2–4 so buckets smaller
than min_num_tokens can be emitted; update the phase start values to respect
min_num_tokens by initializing each phase's m to the maximum of min_num_tokens
and the existing phase-start value (e.g., for Phase 2 set m =
max(min_num_tokens, _PHASE1_END + _PHASE2_STEP), for Phase 3 set m =
max(min_num_tokens, _PHASE2_END + _PHASE3_STEP), and for Phase 4 set m =
max(min_num_tokens, _PHASE3_END * 2)) and keep the existing upper-bound checks
(min(max_num_tokens, _PHASEx_END)) and increments so no bucket below
min_num_tokens is appended; apply this change around the loops that build
buckets (the blocks using variables m, _PHASE1_END, _PHASE2_STEP, _PHASE2_END,
_PHASE3_STEP, and _PHASE3_END).
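A quick reproduction of this finding, assuming the new function is importable as shown in the import hunks above (the output is hand-traced from the code in this diff, not from running the package):

```python
from flashinfer.fused_moe.utils import get_hybrid_num_tokens_buckets

# min_num_tokens=600 skips phase 1 entirely (600 > 256), but phase 2 still
# starts at 512 regardless, so a bucket below the requested minimum is emitted.
buckets = get_hybrid_num_tokens_buckets(4096, min_num_tokens=600)
print(buckets)
# Expected by hand-tracing the code above:
# (512, 768, 1024, 1280, 1536, 1792, 2048, 2560, 3072, 3584, 4096)
```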


+    if not buckets or buckets[-1] != max_num_tokens:
+        buckets.append(max_num_tokens)
+
+    return tuple(sorted(set(buckets)))
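Hand-traced example of the four-phase output for the default cap used at the MoE call sites (values derived from the code above, not from running the package):

```python
from flashinfer.fused_moe.utils import get_hybrid_num_tokens_buckets

assert get_hybrid_num_tokens_buckets(8192) == (
    1, 2, 4, 8, 16, 32, 64, 128, 256,        # phase 1: power-of-2 up to 256
    512, 768, 1024, 1280, 1536, 1792, 2048,  # phase 2: linear, step 256
    2560, 3072, 3584, 4096,                  # phase 3: linear, step 512
    8192,                                    # phase 4: power-of-2 beyond 4096
)
```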


+def map_to_hybrid_bucket(x: int, max_num_tokens: int) -> int:
+    """Map an arbitrary num_tokens to the nearest hybrid bucket (rounding up).
+
+    Mirrors the four-phase spacing of :func:`get_hybrid_num_tokens_buckets`.
+    The result is clamped to ``[1, max_num_tokens]``.
+    """
+    if x <= 0:
+        return 1
+    if x >= max_num_tokens:
+        return max_num_tokens
+    if x <= _PHASE1_END:
+        return next_positive_power_of_2(x)
+    if x <= _PHASE2_END:
+        return min(_ceil_to_step(x, _PHASE2_STEP), max_num_tokens)
+    if x <= _PHASE3_END:
+        return min(_ceil_to_step(x, _PHASE3_STEP), max_num_tokens)
+    return min(next_positive_power_of_2(x), max_num_tokens)
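A few hand-traced mappings, again derived from the code above:

```python
from flashinfer.fused_moe.utils import map_to_hybrid_bucket

assert map_to_hybrid_bucket(3, 8192) == 4         # phase 1: next power of 2
assert map_to_hybrid_bucket(300, 8192) == 512     # phase 2: ceil to step 256
assert map_to_hybrid_bucket(2100, 8192) == 2560   # phase 3: ceil to step 512
assert map_to_hybrid_bucket(5000, 8192) == 8192   # phase 4: next power of 2
assert map_to_hybrid_bucket(99999, 8192) == 8192  # clamped to max_num_tokens
```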


+def map_to_hybrid_bucket_uncapped(x: int) -> int:
+    """One-argument variant for use as a function reference in GEMM tuning.
+
+    Same rounding logic as :func:`map_to_hybrid_bucket` but without the
+    ``max_num_tokens`` clamp (the autotuner already handles upper-bound
+    clamping via the generated bucket list).
+    """
+    if x <= 0:
+        return 1
+    if x <= _PHASE1_END:
+        return next_positive_power_of_2(x)
+    if x <= _PHASE2_END:
+        return _ceil_to_step(x, _PHASE2_STEP)
+    if x <= _PHASE3_END:
+        return _ceil_to_step(x, _PHASE3_STEP)
+    return next_positive_power_of_2(x)
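The uncapped variant behaves identically except for the missing clamp; for instance:

```python
from flashinfer.fused_moe.utils import map_to_hybrid_bucket_uncapped

assert map_to_hybrid_bucket_uncapped(300) == 512          # same phase-2 rounding
assert map_to_hybrid_bucket_uncapped(100_000) == 131_072  # 2**17, no upper clamp
```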


def get_fp4_shape(input_shape, sf_vec_size, is_swizzled_layout=True):
28 changes: 14 additions & 14 deletions flashinfer/gemm/gemm_base.py
@@ -32,8 +32,8 @@
TuningConfig,
)
from ..fused_moe.utils import (
-    get_last_power_of_2_num_tokens_buckets,
-    last_positive_power_of_2,
+    get_hybrid_num_tokens_buckets,
+    map_to_hybrid_bucket_uncapped,
)
from ..utils import (
get_native_fp4_dtype,
@@ -824,8 +824,8 @@ def forward(
DynamicTensorSpec(
(0,), # a_tensor_index
(-2,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
+            get_hybrid_num_tokens_buckets,
+            map_to_hybrid_bucket_uncapped,
),
),
constraint_specs=(
@@ -880,8 +880,8 @@ def forward(
DynamicTensorSpec(
(0,), # a_tensor_index
(-2,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
+            get_hybrid_num_tokens_buckets,
+            map_to_hybrid_bucket_uncapped,
),
),
constraint_specs=(
@@ -1182,8 +1182,8 @@ def tgv_gemm_sm100(
DynamicTensorSpec(
(a_tensor_index,),
(-2,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
+            get_hybrid_num_tokens_buckets,
+            map_to_hybrid_bucket_uncapped,
),
),
constraint_specs=(
@@ -3904,8 +3904,8 @@ def _mxfp8_swizzled_scale_len(m: int, k: int) -> int:
DynamicTensorSpec(
(0,), # a_tensor_index
(0,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
+            get_hybrid_num_tokens_buckets,
+            map_to_hybrid_bucket_uncapped,
),
),
constraint_specs=(
@@ -3928,8 +3928,8 @@ def _mxfp8_swizzled_scale_len(m: int, k: int) -> int:
DynamicTensorSpec(
(0,), # a_tensor_index
(0,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
+            get_hybrid_num_tokens_buckets,
+            map_to_hybrid_bucket_uncapped,
),
),
constraint_specs=(
@@ -3952,8 +3952,8 @@ def _mxfp8_swizzled_scale_len(m: int, k: int) -> int:
DynamicTensorSpec(
(0,), # a_tensor_index
(0,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
+            get_hybrid_num_tokens_buckets,
+            map_to_hybrid_bucket_uncapped,
),
),
constraint_specs=(
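Note the wiring difference from the MoE call sites: in this file the bucket generator and the mapper are passed by reference rather than pre-applied, so the mapper must be a one-argument callable, which is why the uncapped variant is used here. A sketch of the two styles (treating the spec fields as plain callables is an assumption):

```python
from flashinfer.fused_moe.utils import (
    get_hybrid_num_tokens_buckets,
    map_to_hybrid_bucket,
    map_to_hybrid_bucket_uncapped,
)

# MoE style (core.py): the cap is known up front, so buckets are precomputed
# and the mapper closes over the cap.
moe_buckets = get_hybrid_num_tokens_buckets(8192)
moe_mapper = lambda x: map_to_hybrid_bucket(x, 8192)

# GEMM style (gemm_base.py): no cap is baked in, so the generator and the
# one-argument uncapped mapper are passed by reference for the autotuner
# to invoke later.
gemm_bucket_gen = get_hybrid_num_tokens_buckets
gemm_mapper = map_to_hybrid_bucket_uncapped
```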
8 changes: 4 additions & 4 deletions flashinfer/trtllm_low_latency_gemm.py
@@ -36,8 +36,8 @@
OptimizationProfile,
)
from flashinfer.fused_moe.utils import (
-    get_last_power_of_2_num_tokens_buckets,
-    last_positive_power_of_2,
+    get_hybrid_num_tokens_buckets,
+    map_to_hybrid_bucket_uncapped,
)
from flashinfer.jit import setup_cubin_loader
from flashinfer.utils import _get_cache_buf
@@ -170,8 +170,8 @@ def trtllm_low_latency_gemm(
DynamicTensorSpec(
(a_tensor_index,),
(-2,),
-            get_last_power_of_2_num_tokens_buckets,
-            last_positive_power_of_2,
+            get_hybrid_num_tokens_buckets,
+            map_to_hybrid_bucket_uncapped,
),
),
constraint_specs=(
2 changes: 1 addition & 1 deletion version.txt
@@ -1 +1 @@
-0.6.5
+0.6.6