@@ -95,6 +95,7 @@ def __call__(
9595 mW : cute .Tensor ,
9696 M : Int32 ,
9797 eps : Float32 ,
98+ enable_pdl : cutlass .Constexpr [bool ],
9899 stream ,
99100 ):
100101 tv_shape , tv_stride = make_tv_layout (
@@ -105,11 +106,12 @@ def __call__(
105106 tv_layout = cute .make_layout (tv_shape , stride = tv_stride )
106107 tiler_mn = (1 , self .cols_per_tile )
107108
108- self .kernel (mX , mR , mW , M , eps , tv_layout , tiler_mn ).launch (
109+ self .kernel (mX , mR , mW , M , eps , enable_pdl , tv_layout , tiler_mn ).launch (
109110 grid = [M , 1 , 1 ],
110111 block = [self .num_threads , 1 , 1 ],
111112 smem = self ._smem_size_in_bytes (),
112113 stream = stream ,
114+ use_pdl = enable_pdl ,
113115 )
114116
115117 @cute .kernel
@@ -120,12 +122,17 @@ def kernel(
120122 mW : cute .Tensor ,
121123 M : Int32 ,
122124 eps : Float32 ,
125+ enable_pdl : cutlass .Constexpr [bool ],
123126 tv_layout : cute .Layout ,
124127 tiler_mn : cute .Shape ,
125128 ):
126129 tidx , _ , _ = cute .arch .thread_idx ()
127130 bidx , _ , _ = cute .arch .block_idx ()
128131
132+ # PDL (Programmatic Dependent Launch): wait for the previous kernel (SM90+ only)
133+ if enable_pdl :
134+ cute .arch .griddepcontrol_wait ()
135+
129136 H = self .H
130137 weight_bias = self .weight_bias
131138 threads_per_row = tv_layout .shape [0 ][0 ]
@@ -210,6 +217,10 @@ def kernel(
210217
211218 cute .copy (copy_atom , tXrY , tYgX , pred = tXpX )
212219
220+ # PDL (Programmatic Dependent Launch): signal that dependent kernels may launch (SM90+ only)
221+ if enable_pdl :
222+ cute .arch .griddepcontrol_launch_dependents ()
223+
213224
214225# =============================================================================
215226# FusedAddRMSNormQuantKernel
@@ -264,6 +275,7 @@ def __call__(
264275 M : Int32 ,
265276 scale : Float32 ,
266277 eps : Float32 ,
278+ enable_pdl : cutlass .Constexpr [bool ],
267279 stream ,
268280 ):
269281 tv_shape , tv_stride = make_tv_layout (
@@ -274,11 +286,14 @@ def __call__(
274286 tv_layout = cute .make_layout (tv_shape , stride = tv_stride )
275287 tiler_mn = (1 , self .cols_per_tile )
276288
277- self .kernel (mY , mX , mR , mW , M , scale , eps , tv_layout , tiler_mn ).launch (
289+ self .kernel (
290+ mY , mX , mR , mW , M , scale , eps , enable_pdl , tv_layout , tiler_mn
291+ ).launch (
278292 grid = [M , 1 , 1 ],
279293 block = [self .num_threads , 1 , 1 ],
280294 smem = self ._smem_size_in_bytes (),
281295 stream = stream ,
296+ use_pdl = enable_pdl ,
282297 )
283298
284299 @cute .kernel
@@ -291,12 +306,17 @@ def kernel(
291306 M : Int32 ,
292307 scale : Float32 ,
293308 eps : Float32 ,
309+ enable_pdl : cutlass .Constexpr [bool ],
294310 tv_layout : cute .Layout ,
295311 tiler_mn : cute .Shape ,
296312 ):
297313 tidx , _ , _ = cute .arch .thread_idx ()
298314 bidx , _ , _ = cute .arch .block_idx ()
299315
316+ # PDL (Programmatic Dependent Launch): wait for the previous kernel (SM90+ only)
317+ if enable_pdl :
318+ cute .arch .griddepcontrol_wait ()
319+
300320 H = self .H
301321 weight_bias = self .weight_bias
302322 threads_per_row = tv_layout .shape [0 ][0 ]
@@ -396,14 +416,20 @@ def kernel(
396416 out_ptr = get_ptr_as_int64 (mY , Int32 (out_offset ))
397417 cvt_and_store_f32_to_e4m3 (clamped , out_ptr )
398418
419+ # PDL (Programmatic Dependent Launch): signal that dependent kernels may launch (SM90+ only)
420+ if enable_pdl :
421+ cute .arch .griddepcontrol_launch_dependents ()
422+
399423
400424# =============================================================================
401425# Compiled Kernel Getters
402426# =============================================================================
403427
404428
405429@functools .cache
406- def _get_compiled_fused_add_rmsnorm_kernel (dtype_str : str , H : int , weight_bias : float ):
430+ def _get_compiled_fused_add_rmsnorm_kernel (
431+ dtype_str : str , H : int , weight_bias : float , enable_pdl : bool
432+ ):
407433 """Get a compiled Fused Add + RMSNorm kernel using TVM-FFI."""
408434 dtype = get_cutlass_dtype (dtype_str )
409435 kernel_obj = FusedAddRMSNormKernel (dtype , H , weight_bias )
@@ -429,6 +455,7 @@ def _get_compiled_fused_add_rmsnorm_kernel(dtype_str: str, H: int, weight_bias:
429455 w_fake ,
430456 Int32 (1 ),
431457 Float32 (1e-6 ),
458+ enable_pdl ,
432459 stream_fake ,
433460 options = "--enable-tvm-ffi" ,
434461 )
@@ -453,7 +480,7 @@ def tensor_api(
453480
454481@functools .cache
455482def _get_compiled_fused_add_rmsnorm_quant_kernel (
456- dtype_str : str , out_dtype_str : str , H : int , weight_bias : float
483+ dtype_str : str , out_dtype_str : str , H : int , weight_bias : float , enable_pdl : bool
457484):
458485 """Get a compiled Fused Add + RMSNorm + Quant kernel using TVM-FFI."""
459486 dtype = get_cutlass_dtype (dtype_str )
@@ -487,6 +514,7 @@ def _get_compiled_fused_add_rmsnorm_quant_kernel(
487514 Int32 (1 ),
488515 Float32 (1.0 ), # scale
489516 Float32 (1e-6 ),
517+ enable_pdl ,
490518 stream_fake ,
491519 options = "--enable-tvm-ffi" ,
492520 )
@@ -536,7 +564,9 @@ def fused_add_rmsnorm_cute(
536564 M = input .shape [0 ]
537565
538566 dtype_str = _torch_dtype_to_str (input .dtype )
539- kernel = _get_compiled_fused_add_rmsnorm_kernel (dtype_str , H , weight_bias )
567+ kernel = _get_compiled_fused_add_rmsnorm_kernel (
568+ dtype_str , H , weight_bias , enable_pdl
569+ )
540570 kernel (input , residual , weight , M , eps )
541571
542572
@@ -562,7 +592,7 @@ def fused_add_rmsnorm_quant_cute(
562592 dtype_str = _torch_dtype_to_str (input .dtype )
563593 out_dtype_str = _torch_dtype_to_str (out .dtype )
564594 kernel = _get_compiled_fused_add_rmsnorm_quant_kernel (
565- dtype_str , out_dtype_str , H , weight_bias
595+ dtype_str , out_dtype_str , H , weight_bias , enable_pdl
566596 )
567597 kernel (
568598 out ,
0 commit comments