@@ -296,7 +296,7 @@ def _kernel_matmul_fp8_row(
296
296
SPLIT_K (int): Number of SM's to launch per row.
297
297
USE_BIAS (bool): Whether to use bias.
298
298
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
299
- AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
299
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
300
300
"""
301
301
# Matrix multiplication.
302
302
start_pid = tl .program_id (axis = 0 )
@@ -459,7 +459,7 @@ def _kernel_matmul_fp8_row_no_fast_acc(
459
459
SPLIT_K (int): Number of SM's to launch per row.
460
460
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
461
461
USE_BIAS(bool): Whether to use bias.
462
- AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
462
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
463
463
"""
464
464
# Matrix multiplication.
465
465
@@ -615,7 +615,7 @@ def _kernel_matmul_fp8_row_imprecise_acc(
615
615
SPLIT_K (int): Number of SM's to launch per row.
616
616
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
617
617
USE_BIAS (bool): Whether to use bias.
618
- AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
618
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
619
619
"""
620
620
# Matrix multiplication.
621
621
pid = tl .program_id (0 )
@@ -810,7 +810,7 @@ def _kernel_matmul_fp8_row_tma_persistent(
810
810
GROUP_M (int): Number of groups for M dimension swizzle.
811
811
SPLIT_K (int): Number of SM's to launch per row.
812
812
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
813
- AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
813
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
814
814
"""
815
815
# Matrix multiplication.
816
816
start_pid = tl .program_id (axis = 0 )
@@ -1050,7 +1050,7 @@ def _kernel_matmul_fp8_row_tma_persistent_ws_cooperative(
1050
1050
GROUP_M (int): Number of groups for M dimension swizzle.
1051
1051
SPLIT_K (int): Number of SM's to launch per row.
1052
1052
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
1053
- AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
1053
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
1054
1054
"""
1055
1055
num_tiles = tl .cdiv (M , BLOCK_M ) * tl .cdiv (N , BLOCK_N )
1056
1056
num_pid_m = tl .cdiv (M , BLOCK_M )
@@ -1206,8 +1206,6 @@ def persistent_grid(META):
1206
1206
1207
1207
if no_use_persistent :
1208
1208
logger .info ("Using non-persistent kernel" )
1209
- if bias is not None :
1210
- raise AssertionError ("bias is not supported in non-persistent kernel" )
1211
1209
# pyre-ignore
1212
1210
torch ._library .capture_triton (_kernel_matmul_fp8_row_non_persistent )[grid ](
1213
1211
a ,
@@ -1221,7 +1219,7 @@ def persistent_grid(META):
1221
1219
k_key ,
1222
1220
a_scale ,
1223
1221
b_scale ,
1224
- # bias,
1222
+ bias ,
1225
1223
a .stride (0 ),
1226
1224
a .stride (1 ),
1227
1225
b .stride (0 ),
@@ -1232,7 +1230,7 @@ def persistent_grid(META):
1232
1230
allow_tf32 = allow_tf32 ,
1233
1231
fp8_fast_accum = fp8_fast_accum ,
1234
1232
# GROUP_M=8,
1235
- # USE_BIAS=bias is not None,
1233
+ USE_BIAS = bias is not None ,
1236
1234
AB_DTYPE = False ,
1237
1235
)
1238
1236
elif use_warp_specialization :
@@ -1679,7 +1677,7 @@ def _kernel_matmul_fp8_block_fastacc(
1679
1677
GROUP_M (int): Number of groups for M dimension swizzle.
1680
1678
SPLIT_K (int): Number of SM's to launch per row.
1681
1679
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
1682
- AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
1680
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
1683
1681
"""
1684
1682
assert BLOCK_M < scale_block_m
1685
1683
assert BLOCK_N < scale_block_n
@@ -1875,7 +1873,7 @@ def _kernel_matmul_fp8_block_slowacc(
1875
1873
GROUP_M (int): Number of groups for M dimension swizzle.
1876
1874
SPLIT_K (int): Number of SM's to launch per row.
1877
1875
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
1878
- AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
1876
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
1879
1877
"""
1880
1878
assert BLOCK_M < scale_block_m
1881
1879
assert BLOCK_N < scale_block_n
@@ -3172,6 +3170,7 @@ def prune_configs(configs, named_args, **kwargs):
3172
3170
K = named_args ["K" ]
3173
3171
elemBytes_a = named_args ["A" ].element_size ()
3174
3172
elemBytes_b = named_args ["B" ].element_size ()
3173
+ use_bias = kwargs ["USE_BIAS" ]
3175
3174
3176
3175
if M < 32 or N < 32 :
3177
3176
mfma = 16
@@ -3211,6 +3210,9 @@ def prune_configs(configs, named_args, **kwargs):
3211
3210
continue
3212
3211
if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16 :
3213
3212
continue
3213
+ # split_k cannot be used if there is a bias
3214
+ if use_bias and SPLIT_K != 1 :
3215
+ continue
3214
3216
# skip large split_k when not necessary
3215
3217
if SPLIT_K != 1 and not need_split_k (M , N , K ):
3216
3218
continue
@@ -3369,6 +3371,7 @@ def _kernel_matmul_fp8_row_non_persistent(
3369
3371
k_key ,
3370
3372
A_scale ,
3371
3373
B_scale ,
3374
+ Bias ,
3372
3375
stride_am ,
3373
3376
stride_ak ,
3374
3377
stride_bn ,
@@ -3384,6 +3387,7 @@ def _kernel_matmul_fp8_row_non_persistent(
3384
3387
GROUP_M : tl .constexpr ,
3385
3388
SPLIT_K : tl .constexpr ,
3386
3389
EVEN_K : tl .constexpr ,
3390
+ USE_BIAS : tl .constexpr ,
3387
3391
AB_DTYPE : tl .constexpr ,
3388
3392
) -> None :
3389
3393
"""Matmul kernel of [M, K] @ [N, K] with row-wise scales
@@ -3402,6 +3406,7 @@ def _kernel_matmul_fp8_row_non_persistent(
3402
3406
k_key (int): Autotuning key for K dimension of input tensor.
3403
3407
A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
3404
3408
B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
3409
+ Bias (TensorWrapper): [N] Optional bias tensor.
3405
3410
stride_am (int): Stride of M dimension of A.
3406
3411
stride_ak (int): Stride of K dimension of A.
3407
3412
stride_bn (int): Stride of N dimension of B.
@@ -3417,7 +3422,8 @@ def _kernel_matmul_fp8_row_non_persistent(
3417
3422
GROUP_M (int): Number of groups for M dimension swizzle.
3418
3423
SPLIT_K (int): Number of SM's to launch per row.
3419
3424
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
3420
- AB_DTYPE (bool): Wether to cast A and B to C.dtype before tensor core.
3425
+ USE_BIAS (bool): Whether to use bias.
3426
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
3421
3427
"""
3422
3428
tl .assume (M >= 0 )
3423
3429
tl .assume (N >= 0 )
@@ -3484,6 +3490,11 @@ def _kernel_matmul_fp8_row_non_persistent(
3484
3490
scale = a_scale [:, None ] * b_scale [None , :]
3485
3491
acc *= scale
3486
3492
3493
+ # Load and add bias if specified.
3494
+ if USE_BIAS :
3495
+ bias = tl .load (Bias + rn , mask = rn < N )
3496
+ acc += bias [None , :]
3497
+
3487
3498
acc = acc .to (C .dtype .element_ty )
3488
3499
C = C + (rm [:, None ] * stride_cm + rn [None , :] * stride_cn )
3489
3500
mask = (rm < M )[:, None ] & (rn < N )[None , :]
0 commit comments