Fix softmax multi-block path for n_cols > 65536 (unreachable + non-compiling) #1252

Open
lollinng wants to merge 2 commits into linkedin:main from lollinng:fix/softmax-large-vocab-multiblock
lollinng commented Jun 4, 2026

Problem

liger_softmax could not handle rows wider than MAX_FUSED_SIZE (65536), and the multi-block kernels meant to handle them were both unreachable and broken:

  1. Unreachable dispatch. _softmax_forward computed BLOCK_SIZE, num_warps = calculate_settings(n_cols), and calculate_settings returns BLOCK_SIZE = next_power_of_2(n_cols) >= n_cols. The dispatch check 'if n_cols <= BLOCK_SIZE' was therefore always true, so the single-block kernel was always chosen and _softmax_multi_block_forward_kernel was dead code. For n_cols > 65536, calculate_settings raised RuntimeError before any dispatch, so wide-vocab softmax simply errored out.
  2. Broken kernels. Even if reached, the multi-block kernels would not compile: they used tl.float32(-float("inf")) / tl.float32(0.0) (not valid value constructors) and tl.max(m, blk_max), which is a reduction, where the elementwise tl.maximum was intended.

(The existing (1, 4096) "multi-block dispatch" test shapes are <= 65536 and actually hit the single-block path, so this was never caught.)
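
The unreachable-dispatch argument can be checked without a GPU. The following is a minimal pure-Python model of the old logic as described above, not the actual Liger-Kernel code: next_power_of_2 is a stand-in for the Triton helper, and old_calculate_settings / old_dispatch are hypothetical names that return only the block size and the chosen path.

```python
MAX_FUSED_SIZE = 65536  # cap named in the PR description


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1 (stand-in for the Triton helper)."""
    return 1 << (n - 1).bit_length()


def old_calculate_settings(n_cols: int) -> int:
    """Simplified model: return BLOCK_SIZE, or raise if it exceeds the cap."""
    block_size = next_power_of_2(n_cols)
    if block_size > MAX_FUSED_SIZE:
        raise RuntimeError(f"n_cols={n_cols} exceeds MAX_FUSED_SIZE={MAX_FUSED_SIZE}")
    return block_size


def old_dispatch(n_cols: int) -> str:
    """Model of the old branch: 'multi' can never be returned."""
    block_size = old_calculate_settings(n_cols)  # raises for n_cols > 65536
    # block_size >= n_cols by construction, so this condition is always true.
    return "single" if n_cols <= block_size else "multi"
```

In this model every width up to 65536 takes the single-block branch, and any wider row raises before the branch is evaluated, which matches both symptoms listed above.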

Fix

  • Dispatch to the single-block kernel for n_cols <= MAX_FUSED_SIZE (unchanged), and to the multi-block kernel for wider rows, with its tile size capped at MAX_FUSED_SIZE so that it loops over the row tile by tile.
  • Fix the kernel initializers (m = -inf, d = 0.0, acc = 0.0) and use tl.maximum for the running max.
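
To illustrate why the initializers m = -inf and d = 0.0 and an elementwise running max matter, here is a pure-Python sketch of a tiled softmax using the standard online-softmax recurrence. It is an assumption that the multi-block kernel follows this exact recurrence; the function name tiled_softmax and the list-based layout are hypothetical and chosen only for illustration.

```python
import math


def tiled_softmax(row, block_size):
    """Softmax of one row, scanning it in tiles of at most block_size elements.

    Assumes finite inputs. m is the running max, d the running denominator.
    """
    m = -math.inf  # running max starts at -inf so the first tile always wins
    d = 0.0        # running sum of exp(x - m) starts empty
    for start in range(0, len(row), block_size):
        tile = row[start:start + block_size]  # last tile may be shorter (tail)
        new_m = max(m, max(tile))             # combine old max with the tile max
        # Rescale the old denominator to the new max, then add this tile.
        d = d * math.exp(m - new_m) + sum(math.exp(x - new_m) for x in tile)
        m = new_m
    # Second pass: normalize every element with the final max and denominator.
    return [math.exp(x - m) / d for x in row]
```

The step new_m = max(m, max(tile)) separates the two roles that were confused in the original kernels: reducing a tile to its maximum, and combining that value with the running maximum.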

Verification (NVIDIA T4)

Forward and backward vs torch.softmax, fp32:

n_cols=   1024 [single] fwd_match=True bwd_match=True rowsum~1=True
n_cols=  65536 [single] fwd_match=True bwd_match=True rowsum~1=True
n_cols=  70000 [multi]  fwd_match=True bwd_match=True rowsum~1=True
n_cols= 100000 [multi]  fwd_match=True bwd_match=True rowsum~1=True   (non-power-of-2 -> tail masking)
n_cols= 131072 [multi]  fwd_match=True bwd_match=True rowsum~1=True
VERIFY_RESULT: PASS

Added (2, 70000) / (1, 3, 70000) shapes to test_softmax.py so the multi-block path is actually exercised, and corrected the misleading single-block comments.
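
As a quick sanity check on which path a given test shape should exercise after the fix, the dispatch rule can be modeled as below. This is a sketch under two assumptions: softmax is taken over the last dimension, and the multi-block tile is exactly MAX_FUSED_SIZE wide. softmax_path and tiles_per_row are hypothetical helper names, not functions in the repository.

```python
MAX_FUSED_SIZE = 65536  # cap named in the PR description


def softmax_path(shape):
    """Return which kernel family a tensor of this shape is expected to use."""
    n_cols = shape[-1]  # assumption: softmax runs over the last dimension
    return "single" if n_cols <= MAX_FUSED_SIZE else "multi"


def tiles_per_row(shape):
    """Number of MAX_FUSED_SIZE-wide tiles needed to cover one row."""
    n_cols = shape[-1]
    return -(-n_cols // MAX_FUSED_SIZE)  # ceiling division
```

Under this model (1, 4096) stays on the single-block path, while (2, 70000) and (1, 3, 70000) need two tiles per row, so the second tile is partial and relies on tail masking.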

lollinng and others added 2 commits June 5, 2026 02:42
The multi-block softmax path was unreachable and broken:
- _softmax_forward used BLOCK_SIZE = calculate_settings(n_cols) =
  next_power_of_2(n_cols) >= n_cols, so 'n_cols <= BLOCK_SIZE' was always
  true and the multi-block kernel never ran; for n_cols > 65536
  calculate_settings raised instead.
- the multi-block kernels used tl.float32(-inf)/tl.float32(0.0) (not valid
  value constructors) and tl.max(m, blk_max) (a reduction) where elementwise
  tl.maximum was meant, so they would not even compile.

Dispatch to the single-block kernel for n_cols <= MAX_FUSED_SIZE (unchanged)
and to the multi-block kernel, capped to a MAX_FUSED_SIZE tile, for wider
rows. Fix the kernel initializers and the max op.

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The existing 'multi-block dispatch' test shapes (e.g. (1, 4096)) are <=
MAX_FUSED_SIZE and actually take the single-block path, so the multi-block
kernels were never exercised. Add (2, 70000) / (1, 3, 70000) shapes that go
through the multi-block path, and correct the misleading comments.

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>