Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions src/liger_kernel/ops/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import ensure_contiguous

# Largest single-block row width; rows wider than this use the multi-block
# kernels, which loop over the row in tiles of this size. Matches the
# MAX_FUSED_SIZE cap in calculate_settings().
MAX_FUSED_SIZE = 65536


@triton.jit
def _softmax_single_block_forward_kernel(
Expand Down Expand Up @@ -41,14 +46,14 @@ def _softmax_multi_block_forward_kernel(
row_id = tl.program_id(0)
offs = tl.arange(0, BLOCK_SIZE)

m = tl.float32(-float("inf"))
d = tl.float32(0.0)
m = -float("inf")
d = 0.0
for start in tl.range(0, n_cols, BLOCK_SIZE):
idx = start + offs
mask = idx < n_cols
xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
blk_max = tl.max(xblk, axis=0)
new_m = tl.max(m, blk_max)
new_m = tl.maximum(m, blk_max)
d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
m = new_m

Expand Down Expand Up @@ -95,7 +100,7 @@ def _softmax_multi_block_backward_kernel(
):
row_id = tl.program_id(0)
offs = tl.arange(0, BLOCK_SIZE)
acc = tl.float32(0.0)
acc = 0.0

for start in tl.range(0, n_cols, BLOCK_SIZE):
idx = start + offs
Expand All @@ -118,15 +123,19 @@ def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
x2d = x.contiguous().view(-1, n_cols)
n_rows = x2d.shape[0]

BLOCK_SIZE, num_warps = calculate_settings(n_cols)
y2d = torch.empty_like(x2d)

if n_cols <= BLOCK_SIZE:
# For rows that fit in a single block, keep the (faster) single-block path.
# Wider rows would make calculate_settings(n_cols) raise, so cap the block
# size and use the multi-block kernel, which loops over the row in tiles.
if n_cols <= MAX_FUSED_SIZE:
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
_softmax_single_block_forward_kernel[(n_rows,)](
y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
multi_block_launch = False
else:
BLOCK_SIZE, num_warps = calculate_settings(MAX_FUSED_SIZE)
_softmax_multi_block_forward_kernel[(n_rows,)](
y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
Expand Down
8 changes: 6 additions & 2 deletions test/transformers/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
(4, 16),
(1, 1023), # Large single row single-block dispatch
(3, 7, 256), # 3D input
(1, 4096), # test multi-block dispatch
(1, 2, 4096), # test multi-block dispatch on 3D input
(1, 4096), # single-block dispatch (4096 <= MAX_FUSED_SIZE)
(1, 2, 4096), # single-block dispatch on 3D input
(2, 70000), # > MAX_FUSED_SIZE: exercises the multi-block dispatch
(1, 3, 70000), # multi-block dispatch on 3D input
],
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -68,6 +70,8 @@ def test_liger_softmax(shape, dtype, atol, rtol):
(3, 7, 256),
(1, 4096),
(1, 2, 4096),
(2, 70000), # > MAX_FUSED_SIZE: exercises the multi-block dispatch
(1, 3, 70000),
],
)
@pytest.mark.parametrize(
Expand Down