diff --git a/src/liger_kernel/ops/softmax.py b/src/liger_kernel/ops/softmax.py index 15db6cdda..8fe47e9ab 100644 --- a/src/liger_kernel/ops/softmax.py +++ b/src/liger_kernel/ops/softmax.py @@ -7,6 +7,11 @@ from liger_kernel.ops.utils import calculate_settings from liger_kernel.ops.utils import ensure_contiguous +# Largest single-block row width; rows wider than this use the multi-block +# kernels, which loop over the row in tiles of this size. Matches the +# MAX_FUSED_SIZE cap in calculate_settings(). +MAX_FUSED_SIZE = 65536 + @triton.jit def _softmax_single_block_forward_kernel( @@ -41,14 +46,14 @@ def _softmax_multi_block_forward_kernel( row_id = tl.program_id(0) offs = tl.arange(0, BLOCK_SIZE) - m = tl.float32(-float("inf")) - d = tl.float32(0.0) + m = -float("inf") + d = 0.0 for start in tl.range(0, n_cols, BLOCK_SIZE): idx = start + offs mask = idx < n_cols xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca") blk_max = tl.max(xblk, axis=0) - new_m = tl.max(m, blk_max) + new_m = tl.maximum(m, blk_max) d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0) m = new_m @@ -95,7 +100,7 @@ def _softmax_multi_block_backward_kernel( ): row_id = tl.program_id(0) offs = tl.arange(0, BLOCK_SIZE) - acc = tl.float32(0.0) + acc = 0.0 for start in tl.range(0, n_cols, BLOCK_SIZE): idx = start + offs @@ -118,15 +123,19 @@ def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]: x2d = x.contiguous().view(-1, n_cols) n_rows = x2d.shape[0] - BLOCK_SIZE, num_warps = calculate_settings(n_cols) y2d = torch.empty_like(x2d) - if n_cols <= BLOCK_SIZE: + # For rows that fit in a single block, keep the (faster) single-block path. + # Wider rows would make calculate_settings(n_cols) raise, so cap the block + # size and use the multi-block kernel, which loops over the row in tiles. + if n_cols <= MAX_FUSED_SIZE: + BLOCK_SIZE, num_warps = calculate_settings(n_cols) _softmax_single_block_forward_kernel[(n_rows,)]( y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps ) multi_block_launch = False else: + BLOCK_SIZE, num_warps = calculate_settings(MAX_FUSED_SIZE) _softmax_multi_block_forward_kernel[(n_rows,)]( y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps ) diff --git a/test/transformers/test_softmax.py b/test/transformers/test_softmax.py index 6c0666de1..957fe011f 100644 --- a/test/transformers/test_softmax.py +++ b/test/transformers/test_softmax.py @@ -20,8 +20,10 @@ (4, 16), (1, 1023), # Large single row single-block dispatch (3, 7, 256), # 3D input - (1, 4096), # test multi-block dispatch - (1, 2, 4096), # test multi-block dispatch on 3D input + (1, 4096), # single-block dispatch (4096 <= MAX_FUSED_SIZE) + (1, 2, 4096), # single-block dispatch on 3D input + (2, 70000), # > MAX_FUSED_SIZE: exercises the multi-block dispatch + (1, 3, 70000), # multi-block dispatch on 3D input ], ) @pytest.mark.parametrize( @@ -68,6 +70,8 @@ def test_liger_softmax(shape, dtype, atol, rtol): (3, 7, 256), (1, 4096), (1, 2, 4096), + (2, 70000), # > MAX_FUSED_SIZE: exercises the multi-block dispatch + (1, 3, 70000), ], ) @pytest.mark.parametrize(