diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py
index 0cc205858d..e2b25ac6bb 100644
--- a/flashinfer/autotuner.py
+++ b/flashinfer/autotuner.py
@@ -393,6 +393,16 @@ def search_cache(
             cache_key = AutoTuner._get_cache_key(
                 custom_op, r, input_shapes, tuning_config
             )
+            if os.environ.get("FLASHINFER_AUTOTUNER_TRACE", "0") == "1":
+                logger.debug(
+                    f"[AutoTuner] op={custom_op} "
+                    f"runner_hash={hash(r)} "
+                    f"input_shapes={input_shapes} "
+                    f"nearest_profile={cache_key[-1]} "
+                    f"cache_key={cache_key} "
+                    f"hit={cache_key in self.profiling_cache} "
+                    f"cache_size={len(self.profiling_cache)}"
+                )
             if (
                 os.environ.get("FLASHINFER_AUTOTUNER_LOAD_FROM_FILE", "0") == "1"
                 and not self.is_tuning_mode
@@ -770,11 +780,11 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)

         for spec in tuning_config.dynamic_tensor_specs:
-            base_profile[spec.input_idx[0]][spec.dim_idx[0]] = (
-                spec.map_to_tuning_buckets(
-                    base_profile[spec.input_idx[0]][spec.dim_idx[0]]
-                )
+            mapped_value = spec.map_to_tuning_buckets(
+                base_profile[spec.input_idx[0]][spec.dim_idx[0]]
             )
+            for i_idx, d_idx in zip(spec.input_idx, spec.dim_idx, strict=True):
+                base_profile[i_idx][d_idx] = mapped_value

         # associated dimensions dependent on other free dynamic dimensions, so assign -1 in the profile
         for constraint_spec in tuning_config.constraint_specs:
diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py
index 7e0760e7b2..71ab51ce37 100644
--- a/flashinfer/fused_moe/core.py
+++ b/flashinfer/fused_moe/core.py
@@ -51,8 +51,8 @@
     get_compute_capability,
 )
 from .utils import (
-    get_last_power_of_2_num_tokens_buckets,
-    last_positive_power_of_2,
+    get_hybrid_num_tokens_buckets,
+    map_to_hybrid_bucket,
 )
@@ -363,8 +363,8 @@ class MoERunner(TunableRunner):
             DynamicTensorSpec(
                 (0,),
                 (0,),
-                get_last_power_of_2_num_tokens_buckets(8192),
-                lambda x: min(last_positive_power_of_2(x), 8192),
+                get_hybrid_num_tokens_buckets(8192),
+                lambda x: map_to_hybrid_bucket(x, 8192),
             ),
         )
     )
@@ -497,8 +497,8 @@ def refine_tuning_config(cls, tune_max_num_tokens: int):
             DynamicTensorSpec(
                 (0,),
                 (0,),
-                get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens),
-                lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
+                get_hybrid_num_tokens_buckets(tune_max_num_tokens),
+                lambda x: map_to_hybrid_bucket(x, tune_max_num_tokens),
             ),
         )
     )
@@ -979,8 +979,8 @@ class MoERunner(TunableRunner):
             DynamicTensorSpec(
                 (0, 1, 2, 3, 4, 5),
                 (0, 0, 0, 0, 0, 0),
-                get_last_power_of_2_num_tokens_buckets(8192, 1),
-                lambda x: min(last_positive_power_of_2(x), 8192),
+                get_hybrid_num_tokens_buckets(8192, 1),
+                lambda x: map_to_hybrid_bucket(x, 8192),
                 dynamic_tensor_initializers,
             ),
         )
@@ -990,8 +990,8 @@
             DynamicTensorSpec(
                 (0, 1, 2, 3, 4),
                 (0, 0, 0, 0, 0),
-                get_last_power_of_2_num_tokens_buckets(8192, 1),
-                lambda x: min(last_positive_power_of_2(x), 8192),
+                get_hybrid_num_tokens_buckets(8192, 1),
+                lambda x: map_to_hybrid_bucket(x, 8192),
                 dynamic_tensor_initializers[:5],
             ),
         ),
@@ -1291,8 +1291,8 @@ def refine_tuning_config(cls, tune_max_num_tokens: int, **kwargs):
             DynamicTensorSpec(
                 (0, 1, 2, 3, 4, 5),
                 (0, 0, 0, 0, 0, 0),
-                get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),
-                lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
+                get_hybrid_num_tokens_buckets(tune_max_num_tokens, 1),
+                lambda x: map_to_hybrid_bucket(x, tune_max_num_tokens),
                 cls.dynamic_tensor_initializers,
             ),
         ),
@@ -1303,8 +1303,8 @@ def refine_tuning_config(cls, tune_max_num_tokens: int, **kwargs):
             DynamicTensorSpec(
                 (0, 1, 2, 3, 4),
                 (0, 0, 0, 0, 0),
-                get_last_power_of_2_num_tokens_buckets(tune_max_num_tokens, 1),
-                lambda x: min(last_positive_power_of_2(x), tune_max_num_tokens),
+                get_hybrid_num_tokens_buckets(tune_max_num_tokens, 1),
+                lambda x: map_to_hybrid_bucket(x, tune_max_num_tokens),
                 cls.dynamic_tensor_initializers[:5],
             ),
         ),
diff --git a/flashinfer/fused_moe/cute_dsl/tuner.py b/flashinfer/fused_moe/cute_dsl/tuner.py
index ad651d21f1..fb3c226de0 100644
--- a/flashinfer/fused_moe/cute_dsl/tuner.py
+++ b/flashinfer/fused_moe/cute_dsl/tuner.py
@@ -42,8 +42,8 @@
     TuningConfig,
 )
 from ..utils import (
-    get_last_power_of_2_num_tokens_buckets,
-    last_positive_power_of_2,
+    get_hybrid_num_tokens_buckets,
+    map_to_hybrid_bucket,
 )

 logger = logging.getLogger(__name__)
@@ -274,8 +274,8 @@ class CuteDslFusedMoENvfp4Runner(TunableRunner):
             DynamicTensorSpec(
                 input_idx=(0, 1, 2, 3, 11),  # x, x_sf, experts, scales, moe_output
                 dim_idx=(0, 0, 0, 0, 0),  # First dimension is num_tokens for all
-                gen_tuning_buckets=get_last_power_of_2_num_tokens_buckets(8192),
-                map_to_tuning_buckets=lambda x: min(last_positive_power_of_2(x), 8192),
+                gen_tuning_buckets=get_hybrid_num_tokens_buckets(8192),
+                map_to_tuning_buckets=lambda x: map_to_hybrid_bucket(x, 8192),
                 tensor_initializers=dynamic_tensor_initializers,
             ),
         ),
diff --git a/flashinfer/fused_moe/utils.py b/flashinfer/fused_moe/utils.py
index 17764b8599..86dcc00144 100644
--- a/flashinfer/fused_moe/utils.py
+++ b/flashinfer/fused_moe/utils.py
@@ -192,27 +192,102 @@ def nearest_in_buckets(x: int, buckets: List[int]) -> int:
     return min(max(next_positive_power_of_2(x), buckets[0]), buckets[-1])


-def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]:
-    max_num_tokens = next_positive_power_of_2(max_num_tokens)
-    num_token_buckets = []
-    m = max_num_tokens
-    while m >= 1:
-        num_token_buckets.append(m)
-        m //= 2
+_PHASE1_END = 256
+_PHASE2_STEP = 256
+_PHASE2_END = 2048
+_PHASE3_STEP = 512
+_PHASE3_END = 4096

-    return tuple(num_token_buckets)

+def _ceil_to_step(x: int, step: int) -> int:
+    return ((x + step - 1) // step) * step

-def get_last_power_of_2_num_tokens_buckets(
-    max_num_tokens, min_num_tokens=1
+
+def get_hybrid_num_tokens_buckets(
+    max_num_tokens: int, min_num_tokens: int = 1
 ) -> Tuple[int, ...]:
-    max_num_tokens = last_positive_power_of_2(max_num_tokens)
-    num_token_buckets = []
-    m = max_num_tokens
-    while m >= min_num_tokens:
-        num_token_buckets.append(m)
-        m //= 2
-    return tuple(num_token_buckets)
+    """Generate tuning buckets with adaptive spacing.
+
+    Pure power-of-2 spacing creates huge gaps at large values (e.g. a gap of
+    1024 between buckets 1024 and 2048). For MoE workloads,
+    avg_tokens_per_expert can jump across multiple tile boundaries inside a
+    single gap, forcing the autotuner to pick a kernel optimised for a very
+    different workload size.
+
+    This function uses four phases with progressively coarser spacing::
+
+        Phase 1: [min .. 256]   — power-of-2 (step ×2)
+        Phase 2: (256 .. 2048]  — linear step 256
+        Phase 3: (2048 .. 4096] — linear step 512
+        Phase 4: (4096 .. max]  — power-of-2 (step ×2)
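+
+    For example (illustrative, computed from the phases above),
+    ``get_hybrid_num_tokens_buckets(8192)`` yields::
+
+        (1, 2, 4, 8, 16, 32, 64, 128, 256,        # Phase 1
+         512, 768, 1024, 1280, 1536, 1792, 2048,  # Phase 2
+         2560, 3072, 3584, 4096,                  # Phase 3
+         8192)                                    # Phase 4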
+    """
+    buckets: List[int] = []
+
+    # Phase 1: power-of-2 up to _PHASE1_END
+    m = max(min_num_tokens, 1)
+    while m <= min(max_num_tokens, _PHASE1_END):
+        buckets.append(m)
+        m *= 2
+
+    # Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
+    m = _PHASE1_END + _PHASE2_STEP
+    while m <= min(max_num_tokens, _PHASE2_END):
+        buckets.append(m)
+        m += _PHASE2_STEP
+
+    # Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
+    m = _PHASE2_END + _PHASE3_STEP
+    while m <= min(max_num_tokens, _PHASE3_END):
+        buckets.append(m)
+        m += _PHASE3_STEP
+
+    # Phase 4: power-of-2 beyond _PHASE3_END
+    m = _PHASE3_END * 2
+    while m <= max_num_tokens:
+        buckets.append(m)
+        m *= 2
+
+    if not buckets or buckets[-1] != max_num_tokens:
+        buckets.append(max_num_tokens)
+
+    return tuple(sorted(set(buckets)))
+
+
+def map_to_hybrid_bucket(x: int, max_num_tokens: int) -> int:
+    """Map an arbitrary num_tokens to the nearest hybrid bucket (rounding up).
+
+    Mirrors the four-phase spacing of :func:`get_hybrid_num_tokens_buckets`.
+    The result is clamped to ``[1, max_num_tokens]``.
+    """
+    if x <= 0:
+        return 1
+    if x >= max_num_tokens:
+        return max_num_tokens
+    if x <= _PHASE1_END:
+        return next_positive_power_of_2(x)
+    if x <= _PHASE2_END:
+        return min(_ceil_to_step(x, _PHASE2_STEP), max_num_tokens)
+    if x <= _PHASE3_END:
+        return min(_ceil_to_step(x, _PHASE3_STEP), max_num_tokens)
+    return min(next_positive_power_of_2(x), max_num_tokens)
+
+
+def map_to_hybrid_bucket_uncapped(x: int) -> int:
+    """One-argument variant for use as a function reference in GEMM tuning.
+
+    Same rounding logic as :func:`map_to_hybrid_bucket` but without the
+    ``max_num_tokens`` clamp (the autotuner already handles upper-bound
+    clamping via the generated bucket list).
+    """
+    if x <= 0:
+        return 1
+    if x <= _PHASE1_END:
+        return next_positive_power_of_2(x)
+    if x <= _PHASE2_END:
+        return _ceil_to_step(x, _PHASE2_STEP)
+    if x <= _PHASE3_END:
+        return _ceil_to_step(x, _PHASE3_STEP)
+    return next_positive_power_of_2(x)


 def get_fp4_shape(input_shape, sf_vec_size, is_swizzled_layout=True):
diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py
index 60bc5eb76f..1caf50ee1a 100644
--- a/flashinfer/gemm/gemm_base.py
+++ b/flashinfer/gemm/gemm_base.py
@@ -32,8 +32,8 @@
     TuningConfig,
 )
 from ..fused_moe.utils import (
-    get_last_power_of_2_num_tokens_buckets,
-    last_positive_power_of_2,
+    get_hybrid_num_tokens_buckets,
+    map_to_hybrid_bucket_uncapped,
 )
 from ..utils import (
     get_native_fp4_dtype,
@@ -824,8 +824,8 @@ def forward(
             DynamicTensorSpec(
                 (0,),  # a_tensor_index
                 (-2,),
-                get_last_power_of_2_num_tokens_buckets,
-                last_positive_power_of_2,
+                get_hybrid_num_tokens_buckets,
+                map_to_hybrid_bucket_uncapped,
             ),
         ),
         constraint_specs=(
@@ -880,8 +880,8 @@ def forward(
             DynamicTensorSpec(
                 (0,),  # a_tensor_index
                 (-2,),
-                get_last_power_of_2_num_tokens_buckets,
-                last_positive_power_of_2,
+                get_hybrid_num_tokens_buckets,
+                map_to_hybrid_bucket_uncapped,
             ),
         ),
         constraint_specs=(
@@ -1182,8 +1182,8 @@ def tgv_gemm_sm100(
             DynamicTensorSpec(
                 (a_tensor_index,),
                 (-2,),
-                get_last_power_of_2_num_tokens_buckets,
-                last_positive_power_of_2,
+                get_hybrid_num_tokens_buckets,
+                map_to_hybrid_bucket_uncapped,
             ),
         ),
         constraint_specs=(
@@ -3904,8 +3904,8 @@ def _mxfp8_swizzled_scale_len(m: int, k: int) -> int:
             DynamicTensorSpec(
                 (0,),  # a_tensor_index
                 (0,),
-                get_last_power_of_2_num_tokens_buckets,
-                last_positive_power_of_2,
+                get_hybrid_num_tokens_buckets,
+                map_to_hybrid_bucket_uncapped,
             ),
         ),
         constraint_specs=(
@@ -3928,8 +3928,8 @@ def _mxfp8_swizzled_scale_len(m: int, k: int) -> int:
             DynamicTensorSpec(
                 (0,),  # a_tensor_index
                 (0,),
-                get_last_power_of_2_num_tokens_buckets,
-                last_positive_power_of_2,
+                get_hybrid_num_tokens_buckets,
+                map_to_hybrid_bucket_uncapped,
             ),
         ),
         constraint_specs=(
@@ -3952,8 +3952,8 @@ def _mxfp8_swizzled_scale_len(m: int, k: int) -> int:
             DynamicTensorSpec(
                 (0,),  # a_tensor_index
                 (0,),
-                get_last_power_of_2_num_tokens_buckets,
-                last_positive_power_of_2,
+                get_hybrid_num_tokens_buckets,
+                map_to_hybrid_bucket_uncapped,
             ),
         ),
         constraint_specs=(
diff --git a/flashinfer/trtllm_low_latency_gemm.py b/flashinfer/trtllm_low_latency_gemm.py
index e7beb51343..ab7cef1f82 100644
--- a/flashinfer/trtllm_low_latency_gemm.py
+++ b/flashinfer/trtllm_low_latency_gemm.py
@@ -36,8 +36,8 @@
     OptimizationProfile,
 )
 from flashinfer.fused_moe.utils import (
-    get_last_power_of_2_num_tokens_buckets,
-    last_positive_power_of_2,
+    get_hybrid_num_tokens_buckets,
+    map_to_hybrid_bucket_uncapped,
 )
 from flashinfer.jit import setup_cubin_loader
 from flashinfer.utils import _get_cache_buf
@@ -170,8 +170,8 @@ def trtllm_low_latency_gemm(
             DynamicTensorSpec(
                 (a_tensor_index,),
                 (-2,),
-                get_last_power_of_2_num_tokens_buckets,
-                last_positive_power_of_2,
+                get_hybrid_num_tokens_buckets,
+                map_to_hybrid_bucket_uncapped,
             ),
         ),
         constraint_specs=(
diff --git a/version.txt b/version.txt
index ef5e445445..05e8a4593f 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.6.5
+0.6.6
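Sanity check (illustrative, not part of the patch): a quick way to confirm that
the capped mapping always rounds up and lands on a generated bucket, assuming
the patched flashinfer tree is importable:

    from flashinfer.fused_moe.utils import (
        get_hybrid_num_tokens_buckets,
        map_to_hybrid_bucket,
    )

    # Every value in [1, 8192] should map to a bucket that is >= the input
    # and is present in the generated bucket list.
    buckets = set(get_hybrid_num_tokens_buckets(8192))
    for x in range(1, 8193):
        b = map_to_hybrid_bucket(x, 8192)
        assert x <= b <= 8192 and b in buckets, (x, b)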