Fix softmax multi-block path for n_cols > 65536 (unreachable + non-compiling) #1252
Open
lollinng wants to merge 2 commits into
The multi-block softmax path was unreachable and broken:
- _softmax_forward used BLOCK_SIZE = calculate_settings(n_cols) = next_power_of_2(n_cols) >= n_cols, so 'n_cols <= BLOCK_SIZE' was always true and the multi-block kernel never ran; for n_cols > 65536 calculate_settings raised instead.
- the multi-block kernels used tl.float32(-inf)/tl.float32(0.0) (not valid value constructors) and tl.max(m, blk_max) (a reduction) where elementwise tl.maximum was meant, so they would not even compile.

Dispatch to the single-block kernel for n_cols <= MAX_FUSED_SIZE (unchanged) and to the multi-block kernel, capped to a MAX_FUSED_SIZE tile, for wider rows. Fix the kernel initializers and the max op.

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The existing 'multi-block dispatch' test shapes (e.g. (1, 4096)) are <= MAX_FUSED_SIZE and actually take the single-block path, so the multi-block kernels were never exercised. Add (2, 70000) / (1, 3, 70000) shapes that go through the multi-block path, and correct the misleading comments.

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
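The dispatch bug described in the first commit message can be reproduced without a GPU. The sketch below is plain Python, not the library's code: `old_takes_single_block` and `choose_path` are hypothetical names, and the old predicate is modelled without the `RuntimeError` that `calculate_settings` raised for wide rows. It shows why comparing `n_cols` against a block size that is already rounded up can never select the multi-block path, and what a dispatch keyed on `MAX_FUSED_SIZE` might look like.

```python
from typing import Tuple

MAX_FUSED_SIZE = 65536  # limit quoted in the commit message


def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def old_takes_single_block(n_cols: int) -> bool:
    """Old predicate (illustrative): BLOCK_SIZE is rounded up from n_cols,
    so n_cols <= BLOCK_SIZE holds for every input."""
    block_size = next_power_of_2(n_cols)
    return n_cols <= block_size


def choose_path(n_cols: int) -> Tuple[str, int]:
    """Hypothetical fixed dispatch: returns (path name, tile size)."""
    if n_cols <= MAX_FUSED_SIZE:
        # Whole row fits in one tile: single-block path, unchanged.
        return "single_block", next_power_of_2(n_cols)
    # Wider rows: cap the tile at MAX_FUSED_SIZE and loop over the row.
    return "multi_block", MAX_FUSED_SIZE


if __name__ == "__main__":
    for n in (4096, 65536, 70000):
        print(n, old_takes_single_block(n), choose_path(n))
```

With the old predicate, even a 70000-column row reports the single-block path, while the capped dispatch routes it to the multi-block path with a 65536-wide tile.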
Problem

`liger_softmax` could not handle rows wider than `MAX_FUSED_SIZE` (65536), and the multi-block kernels meant to handle them were both unreachable and broken:

- `_softmax_forward` computed `BLOCK_SIZE, num_warps = calculate_settings(n_cols)`, and `calculate_settings` returns `BLOCK_SIZE = next_power_of_2(n_cols) >= n_cols`. So `if n_cols <= BLOCK_SIZE` was always true: the single-block kernel was always chosen and `_softmax_multi_block_forward_kernel` was dead code. For `n_cols > 65536`, `calculate_settings` raised `RuntimeError` before any dispatch, so wide-vocab softmax simply errored out.
- The multi-block kernels used `tl.float32(-float("inf"))` / `tl.float32(0.0)` (not valid value constructors) and `tl.max(m, blk_max)`, a reduction, where the elementwise `tl.maximum` was intended. They would not compile.

(The existing `(1, 4096)` "multi-block dispatch" test shapes are `<= 65536` and actually hit the single-block path, so this was never caught.)

Fix

- Dispatch to the single-block kernel for `n_cols <= MAX_FUSED_SIZE` (unchanged), and to the multi-block kernel, capped to a `MAX_FUSED_SIZE` tile that it loops over, for wider rows.
- Fix the kernel initializers (`m = -inf`, `d = 0.0`, `acc = 0.0`) and use `tl.maximum` for the running max.

Verification (NVIDIA T4)

Forward and backward were checked against `torch.softmax` in fp32.

Added `(2, 70000)` / `(1, 3, 70000)` shapes to `test_softmax.py` so the multi-block path is actually exercised, and corrected the misleading single-block comments.
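For readers unfamiliar with the multi-block approach, the running-max and running-denominator recurrence can be sketched in plain Python. This is an illustration of the general tiled (online) softmax technique under assumed names (`softmax_tiled`, `softmax_reference`), not the Triton kernel from this PR; the variables `m` and `d` mirror the initializers listed under Fix, and the comments mark where a tile reduction differs from an elementwise maximum.

```python
import math
from typing import List


def softmax_reference(row: List[float]) -> List[float]:
    """Single-pass reference: the whole row is visible at once."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_tiled(row: List[float], tile: int) -> List[float]:
    """Tiled softmax: scan the row in fixed-size tiles, keeping a
    running max m and a running denominator d expressed relative to m."""
    m = -math.inf  # running max starts at -inf
    d = 0.0        # running denominator starts at 0.0
    for start in range(0, len(row), tile):
        blk = row[start:start + tile]
        blk_max = max(blk)       # reduction over one tile
        new_m = max(m, blk_max)  # elementwise max of two scalars
        # Rescale the old denominator to the new max, then add this tile.
        d = d * math.exp(m - new_m) + sum(math.exp(x - new_m) for x in blk)
        m = new_m
    # Second pass: normalise with the final max and denominator.
    return [math.exp(x - m) / d for x in row]


if __name__ == "__main__":
    row = [math.sin(i) * 5.0 for i in range(1000)]
    tiled = softmax_tiled(row, tile=64)
    ref = softmax_reference(row)
    print(max(abs(a - b) for a, b in zip(tiled, ref)))
```

The tile size only changes how the row is scanned; the result should agree with the single-pass reference up to floating-point rounding, which is the property the wide-row test shapes check against `torch.softmax`.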