Skip to content

Commit 98cf949

Browse files
njriasan authored and facebook-github-bot committed
Support Bias in _kernel_matmul_fp8_row_non_persistent (#4167)
Summary: Pull Request resolved: #4167 X-link: facebookresearch/FBGEMM#1247 Matches Bias support with the persistent kernel. Reviewed By: karthik-man Differential Revision: D74914575 fbshipit-source-id: bff73c9d13ac193085b7aee927b9701e94cdd4e7
1 parent c845cc9 commit 98cf949

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def _kernel_matmul_fp8_row(
296296
SPLIT_K (int): Number of SM's to launch per row.
297297
USE_BIAS (bool): Whether to use bias.
298298
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
299-
AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
299+
AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
300300
"""
301301
# Matrix multiplication.
302302
start_pid = tl.program_id(axis=0)
@@ -459,7 +459,7 @@ def _kernel_matmul_fp8_row_no_fast_acc(
459459
SPLIT_K (int): Number of SM's to launch per row.
460460
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
461461
USE_BIAS(bool): Whether to use bias.
462-
AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
462+
AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
463463
"""
464464
# Matrix multiplication.
465465

@@ -615,7 +615,7 @@ def _kernel_matmul_fp8_row_imprecise_acc(
615615
SPLIT_K (int): Number of SM's to launch per row.
616616
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
617617
USE_BIAS (bool): Whether to use bias.
618-
AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
618+
AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
619619
"""
620620
# Matrix multiplication.
621621
pid = tl.program_id(0)
@@ -810,7 +810,7 @@ def _kernel_matmul_fp8_row_tma_persistent(
810810
GROUP_M (int): Number of groups for M dimension swizzle.
811811
SPLIT_K (int): Number of SM's to launch per row.
812812
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
813-
AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
813+
AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
814814
"""
815815
# Matrix multiplication.
816816
start_pid = tl.program_id(axis=0)
@@ -1050,7 +1050,7 @@ def _kernel_matmul_fp8_row_tma_persistent_ws_cooperative(
10501050
GROUP_M (int): Number of groups for M dimension swizzle.
10511051
SPLIT_K (int): Number of SM's to launch per row.
10521052
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
1053-
AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
1053+
AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
10541054
"""
10551055
num_tiles = tl.cdiv(M, BLOCK_M) * tl.cdiv(N, BLOCK_N)
10561056
num_pid_m = tl.cdiv(M, BLOCK_M)
@@ -1206,8 +1206,6 @@ def persistent_grid(META):
12061206

12071207
if no_use_persistent:
12081208
logger.info("Using non-persistent kernel")
1209-
if bias is not None:
1210-
raise AssertionError("bias is not supported in non-persistent kernel")
12111209
# pyre-ignore
12121210
torch._library.capture_triton(_kernel_matmul_fp8_row_non_persistent)[grid](
12131211
a,
@@ -1221,7 +1219,7 @@ def persistent_grid(META):
12211219
k_key,
12221220
a_scale,
12231221
b_scale,
1224-
# bias,
1222+
bias,
12251223
a.stride(0),
12261224
a.stride(1),
12271225
b.stride(0),
@@ -1232,7 +1230,7 @@ def persistent_grid(META):
12321230
allow_tf32=allow_tf32,
12331231
fp8_fast_accum=fp8_fast_accum,
12341232
# GROUP_M=8,
1235-
# USE_BIAS=bias is not None,
1233+
USE_BIAS=bias is not None,
12361234
AB_DTYPE=False,
12371235
)
12381236
elif use_warp_specialization:
@@ -1679,7 +1677,7 @@ def _kernel_matmul_fp8_block_fastacc(
16791677
GROUP_M (int): Number of groups for M dimension swizzle.
16801678
SPLIT_K (int): Number of SM's to launch per row.
16811679
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
1682-
AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
1680+
AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
16831681
"""
16841682
assert BLOCK_M < scale_block_m
16851683
assert BLOCK_N < scale_block_n
@@ -1875,7 +1873,7 @@ def _kernel_matmul_fp8_block_slowacc(
18751873
GROUP_M (int): Number of groups for M dimension swizzle.
18761874
SPLIT_K (int): Number of SM's to launch per row.
18771875
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
1878-
AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
1876+
AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
18791877
"""
18801878
assert BLOCK_M < scale_block_m
18811879
assert BLOCK_N < scale_block_n
@@ -3172,6 +3170,7 @@ def prune_configs(configs, named_args, **kwargs):
31723170
K = named_args["K"]
31733171
elemBytes_a = named_args["A"].element_size()
31743172
elemBytes_b = named_args["B"].element_size()
3173+
use_bias = kwargs["USE_BIAS"]
31753174

31763175
if M < 32 or N < 32:
31773176
mfma = 16
@@ -3211,6 +3210,9 @@ def prune_configs(configs, named_args, **kwargs):
32113210
continue
32123211
if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16:
32133212
continue
3213+
# split_k cannot be used if there is a bias
3214+
if use_bias and SPLIT_K != 1:
3215+
continue
32143216
# skip large split_k when not necessary
32153217
if SPLIT_K != 1 and not need_split_k(M, N, K):
32163218
continue
@@ -3369,6 +3371,7 @@ def _kernel_matmul_fp8_row_non_persistent(
33693371
k_key,
33703372
A_scale,
33713373
B_scale,
3374+
Bias,
33723375
stride_am,
33733376
stride_ak,
33743377
stride_bn,
@@ -3384,6 +3387,7 @@ def _kernel_matmul_fp8_row_non_persistent(
33843387
GROUP_M: tl.constexpr,
33853388
SPLIT_K: tl.constexpr,
33863389
EVEN_K: tl.constexpr,
3390+
USE_BIAS: tl.constexpr,
33873391
AB_DTYPE: tl.constexpr,
33883392
) -> None:
33893393
"""Matmul kernel of [M, K] @ [N, K] with row-wise scales
@@ -3402,6 +3406,7 @@ def _kernel_matmul_fp8_row_non_persistent(
34023406
k_key (int): Autotuning key for K dimension of input tensor.
34033407
A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
34043408
B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
3409+
Bias (tensorWrapper): [N] Optional bias tensor.
34053410
stride_am (int): Stride of M dimension of A.
34063411
stride_ak (int): Stride of K dimension of A.
34073412
stride_bn (int): Stride of N dimension of B.
@@ -3417,7 +3422,8 @@ def _kernel_matmul_fp8_row_non_persistent(
34173422
GROUP_M (int): Number of groups for M dimension swizzle.
34183423
SPLIT_K (int): Number of SM's to launch per row.
34193424
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
3420-
AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
3425+
USE_BIAS (bool): Whether to use bias.
3426+
AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
34213427
"""
34223428
tl.assume(M >= 0)
34233429
tl.assume(N >= 0)
@@ -3484,6 +3490,11 @@ def _kernel_matmul_fp8_row_non_persistent(
34843490
scale = a_scale[:, None] * b_scale[None, :]
34853491
acc *= scale
34863492

3493+
# Load and add bias if specified.
3494+
if USE_BIAS:
3495+
bias = tl.load(Bias + rn, mask=rn < N)
3496+
acc += bias[None, :]
3497+
34873498
acc = acc.to(C.dtype.element_ty)
34883499
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
34893500
mask = (rm < M)[:, None] & (rn < N)[None, :]

0 commit comments

Comments (0)