
Commit f2895fa

[Kernels] Enable persistent matmul for fp32 inputs (#9393)
1 parent 69d5bc2 commit f2895fa

5 files changed: 22 additions & 4 deletions


python/triton_kernels/triton_kernels/matmul_details/_common.py

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@
 @triton.constexpr_function
 def get_scaled_dot_format_string(dtype: tl.dtype):
     mapping = {
+        tl.float32: "fp32",
         tl.float16: "fp16",
         tl.bfloat16: "bf16",
         tl.uint8: "e2m1",

python/triton_kernels/triton_kernels/matmul_details/_matmul.py

Lines changed: 2 additions & 1 deletion

@@ -20,7 +20,8 @@

 @triton.jit
 def round_f32_to_tf32(x: tl.tensor):
-    ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;"
+    # use cvt.rn on Hopper+ to match the rounding of TMA.
+    ASM: tl.constexpr = "cvt.rn.tf32.f32 $0, $1;" if cuda_capability_geq(9, 0) else "cvt.rna.tf32.f32 $0, $1;"
     return tl.inline_asm_elementwise(ASM, "=r, r", [x], dtype=tl.float32, is_pure=True, pack=1)

 _matmul_repr = make_matmul_repr("_matmul", [0, 1, 2])
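What tf32 rounding does to an fp32 value, and where the two PTX modes disagree, can be reproduced on the host. The sketch below is not part of the commit; the helper names are made up, it only handles finite values, and it simply mirrors the round-to-nearest-even ("rn") versus ties-away-from-zero ("rna") choice that the ASM string switches on.

import struct

def tf32_round_nearest_even(x: float) -> float:
    # Reinterpret the fp32 value as raw bits; tf32 keeps 10 mantissa bits,
    # so the low 13 bits of the fp32 mantissa are rounded away.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    keep_lsb = (bits >> 13) & 1                      # LSB of the surviving mantissa (ties-to-even)
    bits = (bits + 0x0FFF + keep_lsb) & 0xFFFFE000   # round, then zero the 13 discarded bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

def tf32_round_ties_away(x: float) -> float:
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x1000) & 0xFFFFE000              # add half a tf32 ulp: ties round away from zero
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

# A tie in the discarded bits is where the two modes differ:
x = struct.unpack("<f", struct.pack("<I", 0x3F801000))[0]  # discarded bits are exactly half an ulp
print(tf32_round_nearest_even(x), tf32_round_ties_away(x))  # 1.0 1.0009765625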

python/triton_kernels/triton_kernels/matmul_details/_p_matmul.py

Lines changed: 9 additions & 1 deletion

@@ -46,6 +46,12 @@ def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask):
     return (offs, mask)


+@triton.jit
+def round_f32_to_tf32(x: tl.tensor):
+    ASM: tl.constexpr = "cvt.rn.tf32.f32 $0, $1;" if cuda_capability_geq(9, 0) else "cvt.rna.tf32.f32 $0, $1;"
+    return tl.inline_asm_elementwise(ASM, "=r, r", [x], dtype=tl.float32, is_pure=True, pack=1)
+
+
 _matmul_repr = make_matmul_repr("_p_matmul", [0, 1, 2])
 @triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
             repr=_matmul_repr, launch_metadata=matmul_launch_metadata)

@@ -312,7 +318,9 @@ def _p_matmul(
             x = tl.load(XPtrs)
         else:
             x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
-
+        if x.dtype == tl.float32 and ALLOW_TF32:
+            # since data are not loaded from TMA we need to explicitly round to tf32.
+            x = round_f32_to_tf32(x)
     # --- load x_scale ---
     x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
     if is_x_microscaled:
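The second hunk only touches the non-TMA branch: per the commit's comments, the TMA path already delivers tf32-rounded data, so explicit rounding is needed only for fp32 tiles fetched with a plain tl.load while ALLOW_TF32 is set. A hypothetical one-line restatement of that gating rule (not code from the repo):

def needs_explicit_tf32_round(loaded_via_tma: bool, dtype_is_fp32: bool, allow_tf32: bool) -> bool:
    # Round on the host-managed load path only; TMA loads are assumed already rounded.
    return dtype_is_fp32 and allow_tf32 and not loaded_via_tma

assert needs_explicit_tf32_round(loaded_via_tma=False, dtype_is_fp32=True, allow_tf32=True)
assert not needs_explicit_tf32_round(loaded_via_tma=True, dtype_is_fp32=True, allow_tf32=True)
assert not needs_explicit_tf32_round(loaded_via_tma=False, dtype_is_fp32=True, allow_tf32=False)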

python/triton_kernels/triton_kernels/matmul_details/opt_flags.py

Lines changed: 1 addition & 1 deletion

@@ -248,7 +248,7 @@ def _is_layout_strided(layout: Layout | None) -> bool:
         is_persistent = True
     else:
         has_simple_epilogue = precision_config.max_num_imprecise_acc is None
-        is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.bitwidth <= 8) and out_dtype.bitwidth < 32
+        is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.bitwidth <= 8) and (out_dtype.bitwidth < 32 or lhs_dtype.bitwidth == 32 or rhs_dtype.bitwidth == 32)
     # TMA is slower for batched matmuls with small m/n/k.
     if m * n * k < 131072:
         is_persistent = False
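Restated as plain Python (a sketch with hypothetical scalar arguments, not the repo's actual code), the change is that 32-bit lhs/rhs dtypes now satisfy the dtype check even when the output is also 32-bit, which is what lets fp32 matmuls take the persistent path:

def may_be_persistent(supports_persistent: bool, has_simple_epilogue: bool,
                      tiles_per_sm: float, lhs_bits: int, rhs_bits: int, out_bits: int) -> bool:
    enough_work = tiles_per_sm >= 2.0 or lhs_bits <= 8
    # Before this commit only out_bits < 32 qualified; 32-bit (fp32) inputs now also opt in.
    dtype_ok = out_bits < 32 or lhs_bits == 32 or rhs_bits == 32
    return supports_persistent and has_simple_epilogue and enough_work and dtype_ok

# fp32 x fp32 -> fp32 with enough tiles per SM is now eligible for the persistent kernel:
print(may_be_persistent(True, True, tiles_per_sm=4.0, lhs_bits=32, rhs_bits=32, out_bits=32))  # True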

python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py

Lines changed: 9 additions & 1 deletion

@@ -2,7 +2,7 @@
 import triton
 from triton_kernels import target_info
 from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
-from triton_kernels.tensor import FP4, Tensor, FP16, BF16
+from triton_kernels.tensor import FP4, FP16, FP32, BF16, Tensor
 from triton_kernels.tensor_details.layout import HopperMXScaleLayout
 from triton_kernels.tensor_details.layout_details.blackwell_scale import BlackwellActMXScaleLayout


@@ -143,7 +143,15 @@ def compute_num_stages(
     if x_transpose:
         smem_capacity -= block_m * block_k * (max(8, lhs_dtype.bitwidth) // 8)

+    # Persistent fp32 kernels need extra smem headroom (metadata/barriers/TMA state)
+    # that is not fully captured by the simple stage_size model above.
+    if is_persistent and (lhs_dtype == FP32 or rhs_dtype == FP32):
+        smem_capacity -= 32 * 1024
+        smem_capacity = max(smem_capacity, 0)
     num_stages = min(smem_capacity // int(stage_size), 4)
+    # Keep one stage of headroom for persistent fp32 to avoid launch-time OOR.
+    if is_persistent and (lhs_dtype == FP32 or rhs_dtype == FP32):
+        num_stages = min(num_stages, 3)
     if num_stages == 0:
         num_stages = 1
     return num_stages
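Taken together, the two hunks shrink the usable shared-memory budget and cap the pipeline depth whenever a persistent kernel has fp32 inputs. A standalone sketch with made-up numbers (hypothetical function, not the repo's compute_num_stages) shows the effect on the chosen stage count:

def sketch_num_stages(smem_capacity: int, stage_size: int,
                      is_persistent: bool, fp32_inputs: bool) -> int:
    if is_persistent and fp32_inputs:
        # Mirror the hunk: carve out 32 KiB of headroom, never going negative.
        smem_capacity = max(smem_capacity - 32 * 1024, 0)
    num_stages = min(smem_capacity // stage_size, 4)
    if is_persistent and fp32_inputs:
        num_stages = min(num_stages, 3)
    return max(num_stages, 1)

# Example with made-up sizes: 228 KiB of smem and a 48 KiB per-stage footprint.
print(sketch_num_stages(228 * 1024, 48 * 1024, is_persistent=True, fp32_inputs=False))  # 4
print(sketch_num_stages(228 * 1024, 48 * 1024, is_persistent=True, fp32_inputs=True))   # 3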
