Skip to content

Commit b2d9ec4

Browse files
authored
[BENCH] fix swiglu routing and simplify args (#6743)
bench_mlp.py was incorrect because it did not pass num_experts to the swiglu kernel. That argument defaulted to zero, causing the kernel to load the wrong slot in the expert data. This commit fixes the call and removes the num_experts argument entirely so the API is more foolproof: the number of experts can be derived from the routing data instead.
1 parent 81f93f2 commit b2d9ec4

3 files changed

Lines changed: 11 additions & 12 deletions

File tree

bench/tests/test_swiglu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,6 @@ def test_op(M, N, limit, device, alpha=0.5):
3737
# initialize data
3838
x = alloc_rand([n_tokens, N], device=device, dtype=torch.bfloat16)
3939
precision_config = PrecisionConfig(limit=limit)
40-
tri_y = swiglu(x, alpha, precision_config, routing_data, n_expts_tot)
40+
tri_y = swiglu(x, alpha, precision_config, routing_data)
4141
ref_y = swiglu_torch(x, alpha, precision_config)
4242
assert_close(tri_y, ref_y)

bench/triton_bench/swiglu.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class PrecisionConfig:
2323
class SwiGLU(torch.autograd.Function):
2424

2525
@staticmethod
26-
def forward(ctx, a, alpha, precision_config, routing_data, num_experts):
26+
def forward(ctx, a, alpha, precision_config, routing_data):
2727
N = a.shape[-1]
2828
M = a.numel() // N
2929
assert a.stride()[-1] == 1
@@ -48,9 +48,9 @@ def forward(ctx, a, alpha, precision_config, routing_data, num_experts):
4848
grid = (8 * num_sms, )
4949
else:
5050
grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
51-
expt_data = None
51+
n_tokens = None
5252
if routing_data is not None:
53-
expt_data = compute_metadata(routing_data, M, BLOCK_M).buffer
53+
n_tokens = compute_metadata(routing_data, M, BLOCK_M).offs[routing_data.n_expts_tot]
5454
_swiglu[grid](
5555
flex_ctx.out_data.reinterpret(out),
5656
flex_ctx.out_data.expected_scale,
@@ -66,8 +66,7 @@ def forward(ctx, a, alpha, precision_config, routing_data, num_experts):
6666
out.shape[-1],
6767
1,
6868
precision_config.limit,
69-
expt_data,
70-
num_experts,
69+
n_tokens,
7170
BLOCK_M=BLOCK_M,
7271
BLOCK_N=BLOCK_N,
7372
EVEN_N=(N // 2) % 2 == 0,
@@ -81,8 +80,8 @@ def forward(ctx, a, alpha, precision_config, routing_data, num_experts):
8180
return out
8281

8382

84-
def swiglu(a, alpha, precision_config, routing_data=None, num_experts=0):
85-
return SwiGLU.apply(a, alpha, precision_config, routing_data, num_experts)
83+
def swiglu(a, alpha, precision_config, routing_data=None):
84+
return SwiGLU.apply(a, alpha, precision_config, routing_data)
8685

8786

8887
def swiglu_torch(a, alpha, precision_config):

bench/triton_bench/swiglu_details/_swiglu.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ def swiglu_launch_metadata(grid, kernel, args):
3636

3737
@triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata)
3838
def _swiglu(Out, OutExpectedScale, OutActualScale, OutChecksumScale, A, AScale, alpha, M, N, stride_am, stride_an,
39-
stride_outm, stride_outn, limit: tl.constexpr, ExptData, NUM_EXPERTS: tl.constexpr, BLOCK_M: tl.constexpr,
40-
BLOCK_N: tl.constexpr, EVEN_N: tl.constexpr, M_BLOCKS, N_BLOCKS, flexpoint_saturate_inf: tl.constexpr):
41-
if ExptData is not None:
42-
M = tl.load(ExptData + 2 * NUM_EXPERTS)
39+
stride_outm, stride_outn, limit: tl.constexpr, NTokens, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
40+
EVEN_N: tl.constexpr, M_BLOCKS, N_BLOCKS, flexpoint_saturate_inf: tl.constexpr):
41+
if NTokens is not None:
42+
M = tl.load(NTokens)
4343
M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M
4444

4545
local_max = tl.full([tl.extra.cuda.num_threads()], 0.0, tl.float32)

0 commit comments

Comments
 (0)