From eb5ef81c2be98dcaa7fe448f125d43f1f482ba32 Mon Sep 17 00:00:00 2001 From: Prathamesh Jadhav <55660103+lollinng@users.noreply.github.com> Date: Fri, 5 Jun 2026 02:42:41 +0530 Subject: [PATCH 1/2] Fix softmax multi-block path so wide rows (n_cols > 65536) work The multi-block softmax path was unreachable and broken: - _softmax_forward used BLOCK_SIZE = calculate_settings(n_cols) = next_power_of_2(n_cols) >= n_cols, so 'n_cols <= BLOCK_SIZE' was always true and the multi-block kernel never ran; for n_cols > 65536 calculate_settings raised instead. - the multi-block kernels used tl.float32(-inf)/tl.float32(0.0) (not valid value constructors) and tl.max(m, blk_max) (a reduction) where elementwise tl.maximum was meant, so they would not even compile. Dispatch to the single-block kernel for n_cols <= MAX_FUSED_SIZE (unchanged) and to the multi-block kernel, capped to a MAX_FUSED_SIZE tile, for wider rows. Fix the kernel initializers and the max op. Co-authored-by: Claude Opus 4.8 (1M context) --- src/liger_kernel/ops/softmax.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/liger_kernel/ops/softmax.py b/src/liger_kernel/ops/softmax.py index 15db6cdda..8fe47e9ab 100644 --- a/src/liger_kernel/ops/softmax.py +++ b/src/liger_kernel/ops/softmax.py @@ -7,6 +7,11 @@ from liger_kernel.ops.utils import calculate_settings from liger_kernel.ops.utils import ensure_contiguous +# Largest single-block row width; rows wider than this use the multi-block +# kernels, which loop over the row in tiles of this size. Matches the +# MAX_FUSED_SIZE cap in calculate_settings(). +MAX_FUSED_SIZE = 65536 + @triton.jit def _softmax_single_block_forward_kernel( @@ -41,14 +46,14 @@ def _softmax_multi_block_forward_kernel( row_id = tl.program_id(0) offs = tl.arange(0, BLOCK_SIZE) - m = tl.float32(-float("inf")) - d = tl.float32(0.0) + m = -float("inf") + d = 0.0 for start in tl.range(0, n_cols, BLOCK_SIZE): idx = start + offs mask = idx < n_cols xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca") blk_max = tl.max(xblk, axis=0) - new_m = tl.max(m, blk_max) + new_m = tl.maximum(m, blk_max) d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0) m = new_m @@ -95,7 +100,7 @@ def _softmax_multi_block_backward_kernel( ): row_id = tl.program_id(0) offs = tl.arange(0, BLOCK_SIZE) - acc = tl.float32(0.0) + acc = 0.0 for start in tl.range(0, n_cols, BLOCK_SIZE): idx = start + offs @@ -118,15 +123,19 @@ def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]: x2d = x.contiguous().view(-1, n_cols) n_rows = x2d.shape[0] - BLOCK_SIZE, num_warps = calculate_settings(n_cols) y2d = torch.empty_like(x2d) - if n_cols <= BLOCK_SIZE: + # For rows that fit in a single block, keep the (faster) single-block path. + # Wider rows would make calculate_settings(n_cols) raise, so cap the block + # size and use the multi-block kernel, which loops over the row in tiles. + if n_cols <= MAX_FUSED_SIZE: + BLOCK_SIZE, num_warps = calculate_settings(n_cols) _softmax_single_block_forward_kernel[(n_rows,)]( y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps ) multi_block_launch = False else: + BLOCK_SIZE, num_warps = calculate_settings(MAX_FUSED_SIZE) _softmax_multi_block_forward_kernel[(n_rows,)]( y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps ) From 4b9b2c5277dd15079d804e269095be69818a3ce9 Mon Sep 17 00:00:00 2001 From: Prathamesh Jadhav <55660103+lollinng@users.noreply.github.com> Date: Fri, 5 Jun 2026 02:45:25 +0530 Subject: [PATCH 2/2] test: exercise softmax multi-block path with n_cols > MAX_FUSED_SIZE The existing 'multi-block dispatch' test shapes (e.g. (1, 4096)) are <= MAX_FUSED_SIZE and actually take the single-block path, so the multi-block kernels were never exercised. Add (2, 70000) / (1, 3, 70000) shapes that go through the multi-block path, and correct the misleading comments. Co-authored-by: Claude Opus 4.8 (1M context) --- test/transformers/test_softmax.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/transformers/test_softmax.py b/test/transformers/test_softmax.py index 6c0666de1..957fe011f 100644 --- a/test/transformers/test_softmax.py +++ b/test/transformers/test_softmax.py @@ -20,8 +20,10 @@ (4, 16), (1, 1023), # Large single row single-block dispatch (3, 7, 256), # 3D input - (1, 4096), # test multi-block dispatch - (1, 2, 4096), # test multi-block dispatch on 3D input + (1, 4096), # single-block dispatch (4096 <= MAX_FUSED_SIZE) + (1, 2, 4096), # single-block dispatch on 3D input + (2, 70000), # > MAX_FUSED_SIZE: exercises the multi-block dispatch + (1, 3, 70000), # multi-block dispatch on 3D input ], ) @pytest.mark.parametrize( @@ -68,6 +70,8 @@ def test_liger_softmax(shape, dtype, atol, rtol): (3, 7, 256), (1, 4096), (1, 2, 4096), + (2, 70000), # > MAX_FUSED_SIZE: exercises the multi-block dispatch + (1, 3, 70000), ], ) @pytest.mark.parametrize(