Commit fcd5c5d ("add pdl for all")
Parent: ba1a645
3 files changed: 89 additions & 17 deletions

File: flashinfer/norm/kernels/fused_add_rmsnorm.py (36 additions & 6 deletions)
@@ -95,6 +95,7 @@ def __call__(
         mW: cute.Tensor,
         M: Int32,
         eps: Float32,
+        enable_pdl: cutlass.Constexpr[bool],
         stream,
     ):
         tv_shape, tv_stride = make_tv_layout(
@@ -105,11 +106,12 @@ def __call__(
         tv_layout = cute.make_layout(tv_shape, stride=tv_stride)
         tiler_mn = (1, self.cols_per_tile)

-        self.kernel(mX, mR, mW, M, eps, tv_layout, tiler_mn).launch(
+        self.kernel(mX, mR, mW, M, eps, enable_pdl, tv_layout, tiler_mn).launch(
             grid=[M, 1, 1],
             block=[self.num_threads, 1, 1],
             smem=self._smem_size_in_bytes(),
             stream=stream,
+            use_pdl=enable_pdl,
         )

     @cute.kernel
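The launch-side change and the in-kernel intrinsics work as a pair: use_pdl=enable_pdl opts the grid into programmatic dependent launch (PDL), and the griddepcontrol calls added in the kernel bodies below only buy overlap when the kernel was launched that way. Since PDL is SM90+ only (as the in-kernel comments note), a caller would typically gate the flag on device capability. A minimal sketch of such a gate, assuming plain PyTorch; the helper name is illustrative and not part of this commit:

import torch

def pdl_supported() -> bool:
    # Hypothetical helper, not in this diff: PDL needs compute capability 9.0
    # (Hopper) or newer, matching the "SM90+ only" comments in the kernels below.
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9

enable_pdl = torch.cuda.is_available() and pdl_supported()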
@@ -120,12 +122,17 @@ def kernel(
         mW: cute.Tensor,
         M: Int32,
         eps: Float32,
+        enable_pdl: cutlass.Constexpr[bool],
         tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()

+        # PDL: Wait for previous kernel (SM90+ only)
+        if enable_pdl:
+            cute.arch.griddepcontrol_wait()
+
         H = self.H
         weight_bias = self.weight_bias
         threads_per_row = tv_layout.shape[0][0]
@@ -210,6 +217,10 @@ def kernel(

         cute.copy(copy_atom, tXrY, tYgX, pred=tXpX)

+        # PDL: Signal dependent kernels (SM90+ only)
+        if enable_pdl:
+            cute.arch.griddepcontrol_launch_dependents()
+

 # =============================================================================
 # FusedAddRMSNormQuantKernel
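The placement of the two intrinsics is the whole pattern: griddepcontrol_wait() sits at the top of the kernel, before any global data produced by the preceding kernel is read, and griddepcontrol_launch_dependents() comes after the final global store (the cute.copy above), so a dependent grid can start its prologue early without observing partial output. Because enable_pdl is a cutlass.Constexpr, the branch is resolved at compile time and the non-PDL variant should carry none of these instructions. PDL changes scheduling only, never results, so an equivalence check is a reasonable smoke test; a sketch, assuming the fused_add_rmsnorm_cute wrapper updated later in this file takes an enable_pdl keyword and updates input and residual in place (both assumptions):

import torch
from flashinfer.norm.kernels.fused_add_rmsnorm import fused_add_rmsnorm_cute  # assumed import path

x = torch.randn(4, 4096, device="cuda", dtype=torch.bfloat16)
r = torch.randn_like(x)
w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
x_ref, r_ref = x.clone(), r.clone()

# Same inputs through both paths; PDL may only change when the kernel starts,
# not what it computes, so the in-place results must match.
fused_add_rmsnorm_cute(x_ref, r_ref, w, eps=1e-6, enable_pdl=False)
fused_add_rmsnorm_cute(x, r, w, eps=1e-6, enable_pdl=True)
torch.testing.assert_close(x, x_ref, rtol=0, atol=0)
torch.testing.assert_close(r, r_ref, rtol=0, atol=0)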
@@ -264,6 +275,7 @@ def __call__(
         M: Int32,
         scale: Float32,
         eps: Float32,
+        enable_pdl: cutlass.Constexpr[bool],
         stream,
     ):
         tv_shape, tv_stride = make_tv_layout(
@@ -274,11 +286,14 @@ def __call__(
         tv_layout = cute.make_layout(tv_shape, stride=tv_stride)
         tiler_mn = (1, self.cols_per_tile)

-        self.kernel(mY, mX, mR, mW, M, scale, eps, tv_layout, tiler_mn).launch(
+        self.kernel(
+            mY, mX, mR, mW, M, scale, eps, enable_pdl, tv_layout, tiler_mn
+        ).launch(
             grid=[M, 1, 1],
             block=[self.num_threads, 1, 1],
             smem=self._smem_size_in_bytes(),
             stream=stream,
+            use_pdl=enable_pdl,
         )

     @cute.kernel
@@ -291,12 +306,17 @@ def kernel(
         M: Int32,
         scale: Float32,
         eps: Float32,
+        enable_pdl: cutlass.Constexpr[bool],
         tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()

+        # PDL: Wait for previous kernel (SM90+ only)
+        if enable_pdl:
+            cute.arch.griddepcontrol_wait()
+
         H = self.H
         weight_bias = self.weight_bias
         threads_per_row = tv_layout.shape[0][0]
@@ -396,14 +416,20 @@ def kernel(
         out_ptr = get_ptr_as_int64(mY, Int32(out_offset))
         cvt_and_store_f32_to_e4m3(clamped, out_ptr)

+        # PDL: Signal dependent kernels (SM90+ only)
+        if enable_pdl:
+            cute.arch.griddepcontrol_launch_dependents()
+

 # =============================================================================
 # Compiled Kernel Getters
 # =============================================================================


 @functools.cache
-def _get_compiled_fused_add_rmsnorm_kernel(dtype_str: str, H: int, weight_bias: float):
+def _get_compiled_fused_add_rmsnorm_kernel(
+    dtype_str: str, H: int, weight_bias: float, enable_pdl: bool
+):
     """Get a compiled Fused Add + RMSNorm kernel using TVM-FFI."""
     dtype = get_cutlass_dtype(dtype_str)
     kernel_obj = FusedAddRMSNormKernel(dtype, H, weight_bias)
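Because the flag is baked into the compiled kernel (it is a cutlass.Constexpr and also toggles the launch attribute), enable_pdl has to be part of the functools.cache key of these getters; otherwise a cached non-PDL binary could be handed back to a caller asking for PDL, or vice versa. A toy stand-in for the caching behaviour (names are illustrative, not the real getter):

import functools

@functools.cache
def get_kernel(dtype_str: str, H: int, enable_pdl: bool):
    # Stand-in for _get_compiled_fused_add_rmsnorm_kernel: one compilation per
    # distinct (dtype_str, H, enable_pdl) tuple; later calls reuse the cached result.
    print(f"compiling dtype={dtype_str} H={H} pdl={enable_pdl}")
    return ("compiled-kernel", dtype_str, H, enable_pdl)

get_kernel("bfloat16", 4096, False)  # compiles the non-PDL variant
get_kernel("bfloat16", 4096, True)   # compiles a separate PDL variant
get_kernel("bfloat16", 4096, True)   # cache hit, nothing recompiled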
@@ -429,6 +455,7 @@ def _get_compiled_fused_add_rmsnorm_kernel(dtype_str: str, H: int, weight_bias:
         w_fake,
         Int32(1),
         Float32(1e-6),
+        enable_pdl,
         stream_fake,
         options="--enable-tvm-ffi",
     )
@@ -453,7 +480,7 @@ def tensor_api(

 @functools.cache
 def _get_compiled_fused_add_rmsnorm_quant_kernel(
-    dtype_str: str, out_dtype_str: str, H: int, weight_bias: float
+    dtype_str: str, out_dtype_str: str, H: int, weight_bias: float, enable_pdl: bool
 ):
     """Get a compiled Fused Add + RMSNorm + Quant kernel using TVM-FFI."""
     dtype = get_cutlass_dtype(dtype_str)
@@ -487,6 +514,7 @@ def _get_compiled_fused_add_rmsnorm_quant_kernel(
         Int32(1),
         Float32(1.0),  # scale
         Float32(1e-6),
+        enable_pdl,
         stream_fake,
         options="--enable-tvm-ffi",
     )
@@ -536,7 +564,9 @@ def fused_add_rmsnorm_cute(
     M = input.shape[0]

     dtype_str = _torch_dtype_to_str(input.dtype)
-    kernel = _get_compiled_fused_add_rmsnorm_kernel(dtype_str, H, weight_bias)
+    kernel = _get_compiled_fused_add_rmsnorm_kernel(
+        dtype_str, H, weight_bias, enable_pdl
+    )
     kernel(input, residual, weight, M, eps)

@@ -562,7 +592,7 @@ def fused_add_rmsnorm_quant_cute(
     dtype_str = _torch_dtype_to_str(input.dtype)
     out_dtype_str = _torch_dtype_to_str(out.dtype)
     kernel = _get_compiled_fused_add_rmsnorm_quant_kernel(
-        dtype_str, out_dtype_str, H, weight_bias
+        dtype_str, out_dtype_str, H, weight_bias, enable_pdl
     )
     kernel(
         out,
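At the Python level the change is a single extra argument threaded from the public wrappers down to the compiled kernel. A usage sketch for back-to-back dependent launches on one stream, assuming fused_add_rmsnorm_cute is importable from the module path shown above and accepts eps and enable_pdl keywords (signature details are assumptions):

import torch
from flashinfer.norm.kernels.fused_add_rmsnorm import fused_add_rmsnorm_cute  # assumed path

hidden = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
residual = torch.zeros_like(hidden)
weights = [torch.randn(4096, device="cuda", dtype=torch.bfloat16) for _ in range(4)]

# Consecutive launches on the same stream: with enable_pdl=True, each kernel's
# griddepcontrol_wait() pairs with the previous kernel's
# griddepcontrol_launch_dependents(), so the next grid can begin its prologue
# before the previous one has fully drained.
for w in weights:
    fused_add_rmsnorm_cute(hidden, residual, w, eps=1e-6, enable_pdl=True)
torch.cuda.synchronize()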

File: flashinfer/norm/kernels/layernorm.py (18 additions & 2 deletions)
@@ -108,6 +108,7 @@ def __call__(
         mBeta: cute.Tensor,
         M: Int32,
         eps: Float32,
+        enable_pdl: cutlass.Constexpr[bool],
         stream,
     ):
         # Layout for input (float16/bfloat16)
@@ -135,6 +136,7 @@ def __call__(
             mBeta,
             M,
             eps,
+            enable_pdl,
             tv_layout,
             tiler_mn,
             tv_layout_f32,
@@ -144,6 +146,7 @@ def __call__(
             block=[self.num_threads, 1, 1],
             smem=self._smem_size_in_bytes(),
             stream=stream,
+            use_pdl=enable_pdl,
         )

     @cute.kernel
@@ -155,6 +158,7 @@ def kernel(
         mBeta: cute.Tensor,
         M: Int32,
         eps: Float32,
+        enable_pdl: cutlass.Constexpr[bool],
         tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
         tv_layout_f32: cute.Layout,
@@ -163,6 +167,10 @@ def kernel(
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()

+        # PDL: Wait for previous kernel (SM90+ only)
+        if enable_pdl:
+            cute.arch.griddepcontrol_wait()
+
         H = self.H
         threads_per_row = tv_layout.shape[0][0]
         num_warps = self.num_warps
@@ -343,14 +351,20 @@ def kernel(

         cute.copy(copy_atom_load, tXrY, tXgY, pred=tXpX)

+        # PDL: Signal dependent kernels (SM90+ only)
+        if enable_pdl:
+            cute.arch.griddepcontrol_launch_dependents()
+

 # =============================================================================
 # Compiled Kernel Getter
 # =============================================================================


 @functools.cache
-def _get_compiled_layernorm_kernel(dtype_str: str, gamma_dtype_str: str, H: int):
+def _get_compiled_layernorm_kernel(
+    dtype_str: str, gamma_dtype_str: str, H: int, enable_pdl: bool
+):
     """Get a compiled LayerNorm kernel using TVM-FFI."""
     dtype = get_cutlass_dtype(dtype_str)
     gamma_dtype = get_cutlass_dtype(gamma_dtype_str)
383397
beta_fake,
384398
Int32(1),
385399
Float32(1e-6),
400+
enable_pdl,
386401
stream_fake,
387402
options="--enable-tvm-ffi",
388403
)
@@ -418,6 +433,7 @@ def layernorm_cute(
418433
gamma: torch.Tensor,
419434
beta: torch.Tensor,
420435
eps: float = 1e-6,
436+
enable_pdl: bool = False,
421437
) -> None:
422438
"""CuTe DSL LayerNorm implementation.
423439
@@ -430,7 +446,7 @@ def layernorm_cute(
430446

431447
dtype_str = _torch_dtype_to_str(input.dtype)
432448
gamma_dtype_str = _torch_dtype_to_str(gamma.dtype)
433-
kernel = _get_compiled_layernorm_kernel(dtype_str, gamma_dtype_str, H)
449+
kernel = _get_compiled_layernorm_kernel(dtype_str, gamma_dtype_str, H, enable_pdl)
434450
kernel(out, input, gamma, beta, M, eps)
435451

436452

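layernorm_cute keeps enable_pdl=False as its default, so existing callers are untouched and opting in is one keyword. A usage sketch based on the kernel(out, input, gamma, beta, M, eps) call above; the out-of-place convention and the float32 gamma/beta dtype are assumptions (the getter takes a separate gamma_dtype_str, so other dtypes may be supported):

import torch
from flashinfer.norm.kernels.layernorm import layernorm_cute  # assumed import path

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
gamma = torch.randn(4096, device="cuda", dtype=torch.float32)  # assumed dtype
beta = torch.randn(4096, device="cuda", dtype=torch.float32)   # assumed dtype
out = torch.empty_like(x)

layernorm_cute(out, x, gamma, beta, eps=1e-6, enable_pdl=True)   # PDL variant
layernorm_cute(out, x, gamma, beta, eps=1e-6, enable_pdl=False)  # default path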