5151ENABLE_FULL_TURNING_SPACE = False
5252
5353
def _use_meta_ws() -> bool:
    """Check if Meta's warp specialization is available, enabled, and on SM100+.

    Returns:
        True only when ``is_sm100_plus()`` holds and the
        ``triton.knobs.nvidia.use_meta_ws`` knob exists and is truthy.
        Any missing piece of that attribute chain yields False rather than
        raising, so Triton builds without the knob take the non-Meta path.
    """
    # Keep the hardware check first so the knob is never touched off SM100+.
    if not is_sm100_plus():
        return False
    # Probe every level with a default: `knobs`, `knobs.nvidia`, and the
    # `use_meta_ws` knob itself may each be absent depending on the Triton build.
    knobs = getattr(triton, "knobs", None)
    nvidia_knobs = getattr(knobs, "nvidia", None)
    # Coerce so the declared `bool` return type holds whatever the knob stores.
    return bool(getattr(nvidia_knobs, "use_meta_ws", False))
62+
63+
5464def _check_tma_alignment (
5565 x : torch .Tensor , w : torch .Tensor , y : torch .Tensor , min_alignment : int = 16
5666) -> bool :
@@ -81,6 +91,20 @@ def _check_tma_alignment(
8191 return (K % min_alignment == 0 ) and (N % min_alignment == 0 )
8292
8393
def _prune_persistent_autows_configs(configs, named_args, **kwargs):  # noqa
    """Drop autotune configs whose data partitioning is unsupported.

    Args:
        configs: Candidate configs; each exposes a ``kwargs`` mapping.
        named_args: Unused; accepted to match the pruning-hook signature.
        **kwargs: Unused; accepted to match the pruning-hook signature.

    Returns:
        ``configs`` itself when Meta warp specialization is not in use;
        otherwise a new list without the unsupported combinations.
    """
    if _use_meta_ws():
        # DATA_PARTITION_FACTOR=2 is only supported with BLOCK_M=256
        return [
            cfg
            for cfg in configs
            if cfg.kwargs.get("DATA_PARTITION_FACTOR", 1) != 2
            or cfg.kwargs.get("BLOCK_M", 0) == 256
        ]
    return configs
106+
107+
84108def _prune_configs_for_tlx_persistent_addmm (configs , named_args , ** kwargs ): # noqa
85109 M = named_args .get ("M" , 0 )
86110 N = named_args .get ("N" , 0 )
@@ -405,6 +429,39 @@ def _get_addmm_tma_ws_persistent_configs(pre_hook=None) -> List[triton.Config]:
405429 return configs
406430
407431
def get_triton_persistent_configs(pre_hook=None) -> List[triton.Config]:
    """Build the autotune config list for the persistent TMA addmm kernel.

    Args:
        pre_hook: Forwarded unchanged to every config.

    Returns:
        With Meta warp specialization in use, a full sweep over block sizes,
        stage counts, epilogue subtiling and data partition factors.
        Otherwise, the configs from ``get_mm_configs`` with
        ``DATA_PARTITION_FACTOR`` and ``EPILOGUE_SUBTILE`` set to 1 in place.
    """
    if _use_meta_ws():

        def _make_config(block_m, block_n, block_k, stages, subtile, dp_factor):
            # One autoWS candidate; the key order matches the kernel meta-params.
            return triton.Config(  # pyre-ignore[28]
                {
                    "BLOCK_M": block_m,
                    "BLOCK_N": block_n,
                    "BLOCK_K": block_k,
                    "GROUP_M": 8,
                    "EPILOGUE_SUBTILE": subtile,
                    "DATA_PARTITION_FACTOR": dp_factor,
                },
                num_stages=stages,
                num_warps=4,
                pre_hook=pre_hook,
                early_tma_store_lowering=1,
                maxRegAutoWS=255,
            )

        block_sizes = (64, 128, 256)
        # TODO: Prune configs to best configs.
        return [
            _make_config(block_m, block_n, block_k, stages, subtile, dp_factor)
            for block_m in block_sizes
            for block_n in block_sizes
            for block_k in block_sizes
            for stages in (2, 3, 4)
            for subtile in (1, 2, 4)
            for dp_factor in (1, 2)
        ]

    # Non-Meta path: reuse the generic matmul configs and pin the two
    # extra meta-parameters to 1 on each config.
    base_configs = get_mm_configs(pre_hook=pre_hook)
    for cfg in base_configs:
        cfg.kwargs.update(DATA_PARTITION_FACTOR=1, EPILOGUE_SUBTILE=1)
    return base_configs
463+
464+
408465@triton_cc (
409466 annotations = {
410467 "M" : "i32" ,
@@ -527,9 +584,104 @@ def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
527584 return pid_m , pid_n
528585
529586
@triton.jit
def _addmm_persistent_tile_body(
    x_desc,
    w_desc,
    y_desc,
    z_desc,
    tile_id,
    num_pid_in_group,
    num_pid_m,
    k_tiles,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    BROADCAST_Y: tl.constexpr,
    NUM_SMS: tl.constexpr,
    EPILOGUE_SUBTILE: tl.constexpr,
    INNER_WARP_SPECIALIZE: tl.constexpr,
):
    """Compute and store one output tile of ``z = x @ w + y``.

    The tile coordinates come from ``_compute_pid``. The product is
    accumulated in float32 over ``k_tiles`` blocks of width ``BLOCK_K``,
    ``y`` is added, and the result is cast to ``z_desc.dtype`` and stored
    through ``z_desc``.

    Args:
        x_desc, w_desc, y_desc, z_desc: Descriptors used via ``.load`` and
            ``.store`` with ``[row, col]`` offsets.
        tile_id, num_pid_in_group, num_pid_m: Forwarded to ``_compute_pid``.
        k_tiles: Trip count of the K loop.
        BLOCK_M, BLOCK_N, BLOCK_K: Tile sizes used for offsets and the
            accumulator shape.
        GROUP_M, NUM_SMS: Forwarded to ``_compute_pid``.
        ALLOW_TF32: Passed to ``tl.dot`` as ``allow_tf32``.
        BROADCAST_Y: When true, ``y`` is always read from row 0 instead of
            the tile's own row offset.
        EPILOGUE_SUBTILE: Number of column slices (1, 2 or 4) the epilogue
            is split into; any other value fails a static assert.
        INNER_WARP_SPECIALIZE: Passed as ``warp_specialize`` to the K loop.

    NOTE(review): for EPILOGUE_SUBTILE > 1 the y/z loads and stores work on
    slices of width ``BLOCK_N // EPILOGUE_SUBTILE``. This presumably relies
    on the descriptors' block shapes being configured to match elsewhere
    (e.g. by the autotune pre_hook) -- confirm.
    """
    pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS)
    # Element offsets of this tile's top-left corner in the output.
    offs_xm = pid_m * BLOCK_M
    offs_wn = pid_n * BLOCK_N

    # Main loop: accumulate x[offs_xm, k] @ w[k, offs_wn] in float32.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in tl.range(0, k_tiles, warp_specialize=INNER_WARP_SPECIALIZE):
        offs_k = k * BLOCK_K
        x = x_desc.load([offs_xm, offs_k])
        w = w_desc.load([offs_k, offs_wn])
        accumulator = tl.dot(x, w, accumulator, allow_tf32=ALLOW_TF32)

    # Epilogue subtiling breaks the store into multiple pieces to reduce
    # shared memory consumption and allow higher stage counts.
    if EPILOGUE_SUBTILE == 1:
        # Single-piece epilogue: add y to the whole tile and store it.
        if BROADCAST_Y:
            y = y_desc.load([0, offs_wn])
        else:
            y = y_desc.load([offs_xm, offs_wn])
        z = (accumulator + y.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn], z)
    elif EPILOGUE_SUBTILE == 2:
        # reshape -> permute -> split separates the accumulator into its
        # left (acc0) and right (acc1) column halves of width BLOCK_N // 2.
        acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
        acc = tl.permute(acc, (0, 2, 1))
        acc0, acc1 = tl.split(acc)
        if BROADCAST_Y:
            y0 = y_desc.load([0, offs_wn])
        else:
            y0 = y_desc.load([offs_xm, offs_wn])
        z0 = (acc0 + y0.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn], z0)
        if BROADCAST_Y:
            y1 = y_desc.load([0, offs_wn + BLOCK_N // 2])
        else:
            y1 = y_desc.load([offs_xm, offs_wn + BLOCK_N // 2])
        z1 = (acc1 + y1.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn + BLOCK_N // 2], z1)
    elif EPILOGUE_SUBTILE == 4:
        # Split into halves first, then split each half again, giving four
        # column slices acc0..acc3 of width BLOCK_N // 4 in left-to-right order.
        acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
        acc = tl.permute(acc, (0, 2, 1))
        acc_lo, acc_hi = tl.split(acc)
        acc_lo = tl.reshape(acc_lo, (BLOCK_M, 2, BLOCK_N // 4))
        acc_lo = tl.permute(acc_lo, (0, 2, 1))
        acc0, acc1 = tl.split(acc_lo)
        acc_hi = tl.reshape(acc_hi, (BLOCK_M, 2, BLOCK_N // 4))
        acc_hi = tl.permute(acc_hi, (0, 2, 1))
        acc2, acc3 = tl.split(acc_hi)
        # Each slice is loaded, added and stored in turn at its own column offset.
        if BROADCAST_Y:
            y0 = y_desc.load([0, offs_wn])
        else:
            y0 = y_desc.load([offs_xm, offs_wn])
        z0 = (acc0 + y0.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn], z0)
        if BROADCAST_Y:
            y1 = y_desc.load([0, offs_wn + BLOCK_N // 4])
        else:
            y1 = y_desc.load([offs_xm, offs_wn + BLOCK_N // 4])
        z1 = (acc1 + y1.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn + BLOCK_N // 4], z1)
        if BROADCAST_Y:
            y2 = y_desc.load([0, offs_wn + 2 * BLOCK_N // 4])
        else:
            y2 = y_desc.load([offs_xm, offs_wn + 2 * BLOCK_N // 4])
        z2 = (acc2 + y2.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn + 2 * BLOCK_N // 4], z2)
        if BROADCAST_Y:
            y3 = y_desc.load([0, offs_wn + 3 * BLOCK_N // 4])
        else:
            y3 = y_desc.load([offs_xm, offs_wn + 3 * BLOCK_N // 4])
        z3 = (acc3 + y3.to(tl.float32)).to(z_desc.dtype)
        z_desc.store([offs_xm, offs_wn + 3 * BLOCK_N // 4], z3)
    else:
        # Reject unsupported subtile counts at compile time.
        tl.static_assert(False, "Unsupported EPILOGUE_SUBTILE value")
680+
530681@triton_autotune (
531- configs = get_mm_configs (pre_hook = _addmm_tma_set_block_size_hook ),
532- key = ["N" , "K" , "WARP_SPECIALIZE" ],
682+ configs = get_triton_persistent_configs (pre_hook = _addmm_tma_set_block_size_hook ),
683+ key = ["M" , "N" , "K" , "WARP_SPECIALIZE" ],
684+ prune_configs_by = {"early_config_prune" : _prune_persistent_autows_configs },
533685)
534686@triton .jit
535687def _addmm_fwd_tma_persistent (
@@ -548,6 +700,9 @@ def _addmm_fwd_tma_persistent(
548700 BROADCAST_Y : tl .constexpr ,
549701 WARP_SPECIALIZE : tl .constexpr ,
550702 NUM_SMS : tl .constexpr ,
703+ EPILOGUE_SUBTILE : tl .constexpr ,
704+ DATA_PARTITION_FACTOR : tl .constexpr ,
705+ USE_META_WS : tl .constexpr ,
551706):
552707 start_pid = tl .program_id (axis = 0 )
553708 num_pid_m = tl .cdiv (M , BLOCK_M )
@@ -557,27 +712,61 @@ def _addmm_fwd_tma_persistent(
557712
558713 num_pid_in_group = GROUP_M * num_pid_n
559714
560- for tile_id in tl .range (
561- start_pid , num_tiles , NUM_SMS , flatten = True , warp_specialize = WARP_SPECIALIZE
562- ):
563- pid_m , pid_n = _compute_pid (
564- tile_id , num_pid_in_group , num_pid_m , GROUP_M , NUM_SMS
565- )
566- offs_xm = pid_m * BLOCK_M
567- offs_wn = pid_n * BLOCK_N
568-
569- accumulator = tl .zeros ((BLOCK_M , BLOCK_N ), dtype = tl .float32 )
570- for k in tl .range (0 , k_tiles , warp_specialize = WARP_SPECIALIZE ):
571- offs_k = k * BLOCK_K
572- x = x_desc .load ([offs_xm , offs_k ])
573- w = w_desc .load ([offs_k , offs_wn ])
574- accumulator = tl .dot (x , w , accumulator , allow_tf32 = ALLOW_TF32 )
575- if BROADCAST_Y :
576- y = y_desc .load ([0 , offs_wn ])
577- else :
578- y = y_desc .load ([offs_xm , offs_wn ])
579- z = (accumulator + y .to (tl .float32 )).to (z_desc .dtype )
580- z_desc .store ([offs_xm , offs_wn ], z )
715+ if USE_META_WS :
716+ # Some arguments are only available in FBexperimental.
717+ # pyre-ignore[28]: smem_alloc_algo is FBexperimental
718+ for tile_id in tl .range (
719+ start_pid ,
720+ num_tiles ,
721+ NUM_SMS ,
722+ flatten = False ,
723+ warp_specialize = WARP_SPECIALIZE ,
724+ data_partition_factor = DATA_PARTITION_FACTOR ,
725+ smem_alloc_algo = 1 ,
726+ ):
727+ _addmm_persistent_tile_body (
728+ x_desc ,
729+ w_desc ,
730+ y_desc ,
731+ z_desc ,
732+ tile_id ,
733+ num_pid_in_group ,
734+ num_pid_m ,
735+ k_tiles ,
736+ BLOCK_M = BLOCK_M ,
737+ BLOCK_N = BLOCK_N ,
738+ BLOCK_K = BLOCK_K ,
739+ GROUP_M = GROUP_M ,
740+ ALLOW_TF32 = ALLOW_TF32 ,
741+ BROADCAST_Y = BROADCAST_Y ,
742+ NUM_SMS = NUM_SMS ,
743+ EPILOGUE_SUBTILE = EPILOGUE_SUBTILE ,
744+ INNER_WARP_SPECIALIZE = tl .constexpr (False ),
745+ )
746+ else :
747+ # Pure OAI Triton version.
748+ for tile_id in tl .range (
749+ start_pid , num_tiles , NUM_SMS , flatten = True , warp_specialize = WARP_SPECIALIZE
750+ ):
751+ _addmm_persistent_tile_body (
752+ x_desc ,
753+ w_desc ,
754+ y_desc ,
755+ z_desc ,
756+ tile_id ,
757+ num_pid_in_group ,
758+ num_pid_m ,
759+ k_tiles ,
760+ BLOCK_M = BLOCK_M ,
761+ BLOCK_N = BLOCK_N ,
762+ BLOCK_K = BLOCK_K ,
763+ GROUP_M = GROUP_M ,
764+ ALLOW_TF32 = ALLOW_TF32 ,
765+ BROADCAST_Y = BROADCAST_Y ,
766+ NUM_SMS = NUM_SMS ,
767+ EPILOGUE_SUBTILE = EPILOGUE_SUBTILE ,
768+ INNER_WARP_SPECIALIZE = WARP_SPECIALIZE ,
769+ )
581770
582771
583772@triton_autotune (
@@ -892,7 +1081,7 @@ def _addmm_fwd_tma_ws_persistent(
8921081 z = (result + y .to (tl .float32 )).to (z_desc .dtype )
8931082 z_buf_view = tlx .local_view (z_buffers , z_idx )
8941083 # If Y and Z are not shared wait for Z to be empty.
895- # If there are shared this already guarenteed .
1084+ # If there are shared this already guaranteed .
8961085 if not Y_Z_SHARED :
8971086 z_empty = tlx .local_view (z_empty_bars , z_idx )
8981087 tlx .barrier_wait (z_empty , z_load_phase ^ 1 )
@@ -1226,8 +1415,12 @@ def triton_addmm_fwd_tma_persistent(
12261415 x : torch .Tensor ,
12271416 w : torch .Tensor ,
12281417 y : torch .Tensor ,
1229- warp_specialize : bool = False ,
1418+ warp_specialize : bool | None = None ,
12301419) -> torch .Tensor :
1420+ _meta_ws = _use_meta_ws ()
1421+ if warp_specialize is None :
1422+ warp_specialize = _meta_ws
1423+
12311424 M , K = x .shape
12321425 _ , N = w .shape
12331426
@@ -1256,7 +1449,6 @@ def triton_addmm_fwd_tma_persistent(
12561449 NUM_SMS = torch .cuda .get_device_properties ("cuda" ).multi_processor_count
12571450
12581451 def grid (meta ):
1259- nonlocal x_desc , w_desc , z_desc
12601452 BLOCK_M = meta ["BLOCK_M" ]
12611453 BLOCK_N = meta ["BLOCK_N" ]
12621454 return (
@@ -1278,6 +1470,7 @@ def grid(meta):
12781470 BROADCAST_Y = is_y_1d ,
12791471 WARP_SPECIALIZE = warp_specialize ,
12801472 NUM_SMS = NUM_SMS ,
1473+ USE_META_WS = _meta_ws ,
12811474 )
12821475 return z
12831476
0 commit comments