fix(autotuner): hybrid bucket spacing and cache-key fix for fused MoE #3063
Closed: StudyingShao wants to merge 3 commits into flashinfer-ai:main from StudyingShao:jiangs/autotuner-hybrid-bucket-spacing
```diff
@@ -192,27 +192,102 @@ def nearest_in_buckets(x: int, buckets: List[int]) -> int:
     return min(max(next_positive_power_of_2(x), buckets[0]), buckets[-1])
 
 
-def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]:
-    max_num_tokens = next_positive_power_of_2(max_num_tokens)
-    num_token_buckets = []
-    m = max_num_tokens
-    while m >= 1:
-        num_token_buckets.append(m)
-        m //= 2
-
-    return tuple(num_token_buckets)
+_PHASE1_END = 256
+_PHASE2_STEP = 256
+_PHASE2_END = 2048
+_PHASE3_STEP = 512
+_PHASE3_END = 4096
 
 
-def get_last_power_of_2_num_tokens_buckets(
-    max_num_tokens, min_num_tokens=1
-) -> Tuple[int, ...]:
-    max_num_tokens = last_positive_power_of_2(max_num_tokens)
-    num_token_buckets = []
-    m = max_num_tokens
-    while m >= min_num_tokens:
-        num_token_buckets.append(m)
-        m //= 2
-    return tuple(num_token_buckets)
+def _ceil_to_step(x: int, step: int) -> int:
+    return ((x + step - 1) // step) * step
+
+
+def get_hybrid_num_tokens_buckets(
+    max_num_tokens: int, min_num_tokens: int = 1
+) -> Tuple[int, ...]:
+    """Generate tuning buckets with adaptive spacing.
+
+    Pure power-of-2 spacing creates huge gaps at large values (e.g. 1024
+    between bucket 1024 and 2048). For MoE workloads the
+    avg_tokens_per_expert can jump across multiple tile boundaries inside a
+    single gap, forcing the autotuner to pick a kernel optimised for a very
+    different workload size.
+
+    This function uses four phases with progressively coarser spacing::
+
+        Phase 1: [min .. 256]   -> power-of-2 (step ×2)
+        Phase 2: (256 .. 2048]  -> linear step 256
+        Phase 3: (2048 .. 4096] -> linear step 512
+        Phase 4: (4096 .. max]  -> power-of-2 (step ×2)
+    """
+    buckets: List[int] = []
+
+    # Phase 1: power-of-2 up to _PHASE1_END
+    m = max(min_num_tokens, 1)
+    while m <= min(max_num_tokens, _PHASE1_END):
+        buckets.append(m)
+        m *= 2
+
+    # Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
+    m = _PHASE1_END + _PHASE2_STEP
+    while m <= min(max_num_tokens, _PHASE2_END):
+        buckets.append(m)
+        m += _PHASE2_STEP
+
+    # Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
+    m = _PHASE2_END + _PHASE3_STEP
+    while m <= min(max_num_tokens, _PHASE3_END):
+        buckets.append(m)
+        m += _PHASE3_STEP
+
+    # Phase 4: power-of-2 beyond _PHASE3_END
+    m = _PHASE3_END * 2
+    while m <= max_num_tokens:
+        buckets.append(m)
+        m *= 2
```
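To make the four-phase spacing concrete, here is a self-contained sketch of the new bucket generation, re-typed from the diff above (treat it as illustrative rather than the canonical source), together with the bucket list it yields for `max_num_tokens = 8192`:

```python
from typing import List, Tuple

_PHASE1_END = 256
_PHASE2_STEP = 256
_PHASE2_END = 2048
_PHASE3_STEP = 512
_PHASE3_END = 4096


def _ceil_to_step(x: int, step: int) -> int:
    # Round x up to the next multiple of step.
    return ((x + step - 1) // step) * step


def get_hybrid_num_tokens_buckets(
    max_num_tokens: int, min_num_tokens: int = 1
) -> Tuple[int, ...]:
    buckets: List[int] = []
    # Phase 1: power-of-2 up to 256
    m = max(min_num_tokens, 1)
    while m <= min(max_num_tokens, _PHASE1_END):
        buckets.append(m)
        m *= 2
    # Phase 2: linear step 256 in (256, 2048]
    m = _PHASE1_END + _PHASE2_STEP
    while m <= min(max_num_tokens, _PHASE2_END):
        buckets.append(m)
        m += _PHASE2_STEP
    # Phase 3: linear step 512 in (2048, 4096]
    m = _PHASE2_END + _PHASE3_STEP
    while m <= min(max_num_tokens, _PHASE3_END):
        buckets.append(m)
        m += _PHASE3_STEP
    # Phase 4: power-of-2 beyond 4096
    m = _PHASE3_END * 2
    while m <= max_num_tokens:
        buckets.append(m)
        m *= 2
    # Ensure max_num_tokens itself is always a bucket.
    if not buckets or buckets[-1] != max_num_tokens:
        buckets.append(max_num_tokens)
    return tuple(sorted(set(buckets)))


print(get_hybrid_num_tokens_buckets(8192))
# (1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 768, 1024, 1280, 1536, 1792,
#  2048, 2560, 3072, 3584, 4096, 8192)
```

Note the effect in the mid range: pure power-of-2 spacing would jump 1024 → 2048 in one step, while the hybrid list inserts 1280, 1536, and 1792 in between, so `avg_tokens_per_expert` never crosses more than one 256-wide gap at a time there.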
Comment on lines +227 to +248 (Contributor):

Suggested fix:

```diff
 def get_hybrid_num_tokens_buckets(
     max_num_tokens: int, min_num_tokens: int = 1
 ) -> Tuple[int, ...]:
+    if max_num_tokens < 1:
+        raise ValueError("max_num_tokens must be >= 1")
+
     buckets: List[int] = []
     # Phase 1: power-of-2 up to _PHASE1_END
-    m = max(min_num_tokens, 1)
+    min_num_tokens = max(min_num_tokens, 1)
+    m = next_positive_power_of_2(min_num_tokens)
     while m <= min(max_num_tokens, _PHASE1_END):
         buckets.append(m)
         m *= 2
     # Phase 2: linear step 256 in (_PHASE1_END, _PHASE2_END]
-    m = _PHASE1_END + _PHASE2_STEP
+    m = _ceil_to_step(max(min_num_tokens, _PHASE1_END + 1), _PHASE2_STEP)
     while m <= min(max_num_tokens, _PHASE2_END):
         buckets.append(m)
         m += _PHASE2_STEP
     # Phase 3: linear step 512 in (_PHASE2_END, _PHASE3_END]
-    m = _PHASE2_END + _PHASE3_STEP
+    m = _ceil_to_step(max(min_num_tokens, _PHASE2_END + 1), _PHASE3_STEP)
     while m <= min(max_num_tokens, _PHASE3_END):
         buckets.append(m)
         m += _PHASE3_STEP
     # Phase 4: power-of-2 beyond _PHASE3_END
-    m = _PHASE3_END * 2
+    m = next_positive_power_of_2(max(min_num_tokens, _PHASE3_END + 1))
     while m <= max_num_tokens:
         buckets.append(m)
         m *= 2
```
```diff
+
+    if not buckets or buckets[-1] != max_num_tokens:
+        buckets.append(max_num_tokens)
+
+    return tuple(sorted(set(buckets)))
+
+
+def map_to_hybrid_bucket(x: int, max_num_tokens: int) -> int:
+    """Map an arbitrary num_tokens to the nearest hybrid bucket (rounding up).
+
+    Mirrors the four-phase spacing of :func:`get_hybrid_num_tokens_buckets`.
+    The result is clamped to ``[1, max_num_tokens]``.
+    """
+    if x <= 0:
+        return 1
+    if x >= max_num_tokens:
+        return max_num_tokens
+    if x <= _PHASE1_END:
+        return next_positive_power_of_2(x)
+    if x <= _PHASE2_END:
+        return min(_ceil_to_step(x, _PHASE2_STEP), max_num_tokens)
+    if x <= _PHASE3_END:
+        return min(_ceil_to_step(x, _PHASE3_STEP), max_num_tokens)
+    return min(next_positive_power_of_2(x), max_num_tokens)
+
+
+def map_to_hybrid_bucket_uncapped(x: int) -> int:
+    """One-argument variant for use as a function reference in GEMM tuning.
+
+    Same rounding logic as :func:`map_to_hybrid_bucket` but without the
+    ``max_num_tokens`` clamp (the autotuner already handles upper-bound
+    clamping via the generated bucket list).
+    """
+    if x <= 0:
+        return 1
+    if x <= _PHASE1_END:
+        return next_positive_power_of_2(x)
+    if x <= _PHASE2_END:
+        return _ceil_to_step(x, _PHASE2_STEP)
+    if x <= _PHASE3_END:
+        return _ceil_to_step(x, _PHASE3_STEP)
+    return next_positive_power_of_2(x)
 
 
 def get_fp4_shape(input_shape, sf_vec_size, is_swizzled_layout=True):
```
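A few worked examples of the rounding-up behavior, using a self-contained re-typing of `map_to_hybrid_bucket` (again illustrative; `next_positive_power_of_2` is assumed to return the smallest power of two at or above its argument):

```python
def next_positive_power_of_2(x: int) -> int:
    # Assumed behavior of the repo helper: smallest power of two >= x.
    return 1 if x < 1 else 1 << (x - 1).bit_length()


def _ceil_to_step(x: int, step: int) -> int:
    return ((x + step - 1) // step) * step


def map_to_hybrid_bucket(x: int, max_num_tokens: int) -> int:
    # Mirrors the four-phase rounding from the diff above, with phase
    # boundaries inlined (256 / 2048 / 4096).
    if x <= 0:
        return 1
    if x >= max_num_tokens:
        return max_num_tokens
    if x <= 256:
        return next_positive_power_of_2(x)
    if x <= 2048:
        return min(_ceil_to_step(x, 256), max_num_tokens)
    if x <= 4096:
        return min(_ceil_to_step(x, 512), max_num_tokens)
    return min(next_positive_power_of_2(x), max_num_tokens)


# With pure power-of-2 rounding, 1100 tokens would map to 2048; the
# hybrid mapping picks the much closer 1280 bucket.
print(map_to_hybrid_bucket(1100, 8192))  # 1280
print(map_to_hybrid_bucket(100, 8192))   # 128  (Phase 1: power of two)
print(map_to_hybrid_bucket(3000, 8192))  # 3072 (Phase 3: step 512)
print(map_to_hybrid_bucket(5000, 8192))  # 8192 (Phase 4: power of two)
```

Keeping this mapping an exact mirror of the bucket generator matters for the cache-key fix: every value this function returns must appear in the generated bucket list, or the autotuner would look up a profile that was never tuned.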
```diff
@@ -1 +1 @@
-0.6.5
+0.6.6
```
Review comment: Reintroduces a known linked-dimension crash path that was reverted in PR #2697. Lines 783-787 propagate the mapped value across all linked dimensions via `zip(spec.input_idx, spec.dim_idx)`. This re-enables behavior that was intentionally reverted because it triggers TRTLLM fused MoE C++ runtime crashes (`launchers_map.at(tile_N)` missing key). The three regression tests that verify this mapping (`test_find_nearest_profile_moe_shared_num_tokens_axis`, `test_find_nearest_profile_moe_same_bucket_same_profile`, `test_find_nearest_profile_maps_all_linked_dims`) remain skipped with messages stating the propagation was reverted. Unless the corresponding C++ fix is included in this commit, revert to first-dimension-only mapping.
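The reviewer's distinction can be illustrated with a toy sketch of the two mapping strategies. Everything below except the `spec.input_idx` / `spec.dim_idx` names (which come from the comment) is hypothetical scaffolding, not the autotuner's real data model:

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class DynamicDimSpec:
    # Hypothetical stand-in for the autotuner's dynamic-dimension spec:
    # input_idx[i] and dim_idx[i] together name one linked (tensor, dim) pair.
    input_idx: Tuple[int, ...]
    dim_idx: Tuple[int, ...]


def map_first_dim_only(shapes: List[List[int]], spec: DynamicDimSpec,
                       bucketed: int) -> None:
    # Reverted (safe) behavior: only the first linked dimension is
    # overwritten with the bucketed value.
    shapes[spec.input_idx[0]][spec.dim_idx[0]] = bucketed


def map_all_linked_dims(shapes: List[List[int]], spec: DynamicDimSpec,
                        bucketed: int) -> None:
    # Behavior the reviewer says this PR reintroduces: every linked
    # dimension receives the bucketed value, which previously triggered
    # the TRTLLM fused MoE crash on the C++ side.
    for t, d in zip(spec.input_idx, spec.dim_idx):
        shapes[t][d] = bucketed


# Two input tensors share a num_tokens axis of 1100, bucketed to 1280.
spec = DynamicDimSpec(input_idx=(0, 1), dim_idx=(0, 0))
shapes = [[1100, 4096], [1100, 4096]]
map_all_linked_dims(shapes, spec, 1280)
print(shapes)  # [[1280, 4096], [1280, 4096]]
```

The crash report suggests the C++ launcher table is only keyed for profiles derived from the first dimension, which is why propagating to all linked dimensions can produce a `tile_N` that has no entry.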
#2697.Lines 783β787 propagate the mapped value across all linked dimensions via
zip(spec.input_idx, spec.dim_idx). This re-enables behavior that was intentionally reverted because it triggers TRTLLM fused MoE C++ runtime crashes (launchers_map.at(tile_N)missing key). The three regression tests that verify this mapping (test_find_nearest_profile_moe_shared_num_tokens_axis,test_find_nearest_profile_moe_same_bucket_same_profile,test_find_nearest_profile_maps_all_linked_dims) remain skipped with messages stating the propagation was reverted. Unless the corresponding C++ fix is included in this commit, revert to first-dimension-only mapping:π€ Prompt for AI Agents