5353ENABLE_FULL_TURNING_SPACE = False
5454
5555
56+ def _use_meta_ws () -> bool :
57+ """Check if Meta's warp specialization is available, enabled, and on SM100+."""
58+ return (
59+ is_sm100_plus ()
60+ and hasattr (triton , "knobs" )
61+ and hasattr (triton .knobs , "nvidia" )
62+ and triton .knobs .nvidia .use_meta_ws
63+ )
64+
65+
5666def _check_tma_alignment (
5767 x : torch .Tensor , w : torch .Tensor , y : torch .Tensor , min_alignment : int = 16
5868) -> bool :
@@ -83,6 +93,20 @@ def _check_tma_alignment(
8393 return (K % min_alignment == 0 ) and (N % min_alignment == 0 )
8494
8595
96+ def _prune_persistent_autows_configs (configs , named_args , ** kwargs ): # noqa
97+ if not _use_meta_ws ():
98+ return configs
99+ pruned = []
100+ for c in configs :
101+ BLOCK_M = c .kwargs .get ("BLOCK_M" , 0 )
102+ DP = c .kwargs .get ("DATA_PARTITION_FACTOR" , 1 )
103+ # DATA_PARTITION_FACTOR=2 is only supported with BLOCK_M=256
104+ if DP == 2 and BLOCK_M != 256 :
105+ continue
106+ pruned .append (c )
107+ return pruned
108+
109+
86110def _prune_configs_for_tlx_persistent_addmm (configs , named_args , ** kwargs ): # noqa
87111 M = named_args .get ("M" , 0 )
88112 N = named_args .get ("N" , 0 )
@@ -407,6 +431,39 @@ def _get_addmm_tma_ws_persistent_configs(pre_hook=None) -> List[triton.Config]:
407431 return configs
408432
409433
434+ def get_triton_persistent_configs (pre_hook = None ) -> List [triton .Config ]:
435+ if not _use_meta_ws ():
436+ configs = get_mm_configs (pre_hook = pre_hook )
437+ for c in configs :
438+ c .kwargs ["DATA_PARTITION_FACTOR" ] = 1
439+ c .kwargs ["EPILOGUE_SUBTILE" ] = 1
440+ return configs
441+ # TODO: Prune configs to best configs.
442+ return [
443+ triton .Config ( # pyre-ignore[28]
444+ {
445+ "BLOCK_M" : block_m ,
446+ "BLOCK_N" : block_n ,
447+ "BLOCK_K" : block_k ,
448+ "GROUP_M" : 8 ,
449+ "EPILOGUE_SUBTILE" : subtile ,
450+ "DATA_PARTITION_FACTOR" : DP ,
451+ },
452+ num_stages = num_stages ,
453+ num_warps = 4 ,
454+ pre_hook = pre_hook ,
455+ early_tma_store_lowering = 1 ,
456+ maxRegAutoWS = 255 ,
457+ )
458+ for block_m in [64 , 128 , 256 ]
459+ for block_n in [64 , 128 , 256 ]
460+ for block_k in [64 , 128 , 256 ]
461+ for num_stages in [2 , 3 , 4 ]
462+ for subtile in [1 , 2 , 4 ]
463+ for DP in [1 , 2 ]
464+ ]
465+
466+
410467@triton_cc (
411468 annotations = {
412469 "M" : "i32" ,
@@ -529,9 +586,104 @@ def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
529586 return pid_m , pid_n
530587
531588
589+ @triton .jit
590+ def _addmm_persistent_tile_body (
591+ x_desc ,
592+ w_desc ,
593+ y_desc ,
594+ z_desc ,
595+ tile_id ,
596+ num_pid_in_group ,
597+ num_pid_m ,
598+ k_tiles ,
599+ BLOCK_M : tl .constexpr ,
600+ BLOCK_N : tl .constexpr ,
601+ BLOCK_K : tl .constexpr ,
602+ GROUP_M : tl .constexpr ,
603+ ALLOW_TF32 : tl .constexpr ,
604+ BROADCAST_Y : tl .constexpr ,
605+ NUM_SMS : tl .constexpr ,
606+ EPILOGUE_SUBTILE : tl .constexpr ,
607+ INNER_WARP_SPECIALIZE : tl .constexpr ,
608+ ):
609+ pid_m , pid_n = _compute_pid (tile_id , num_pid_in_group , num_pid_m , GROUP_M , NUM_SMS )
610+ offs_xm = pid_m * BLOCK_M
611+ offs_wn = pid_n * BLOCK_N
612+
613+ accumulator = tl .zeros ((BLOCK_M , BLOCK_N ), dtype = tl .float32 )
614+ for k in tl .range (0 , k_tiles , warp_specialize = INNER_WARP_SPECIALIZE ):
615+ offs_k = k * BLOCK_K
616+ x = x_desc .load ([offs_xm , offs_k ])
617+ w = w_desc .load ([offs_k , offs_wn ])
618+ accumulator = tl .dot (x , w , accumulator , allow_tf32 = ALLOW_TF32 )
619+
620+ # Epilogue subtiling breaks the store into multiple pieces to reduce
621+ # shared memory consumption and allow higher stage counts.
622+ if EPILOGUE_SUBTILE == 1 :
623+ if BROADCAST_Y :
624+ y = y_desc .load ([0 , offs_wn ])
625+ else :
626+ y = y_desc .load ([offs_xm , offs_wn ])
627+ z = (accumulator + y .to (tl .float32 )).to (z_desc .dtype )
628+ z_desc .store ([offs_xm , offs_wn ], z )
629+ elif EPILOGUE_SUBTILE == 2 :
630+ acc = tl .reshape (accumulator , (BLOCK_M , 2 , BLOCK_N // 2 ))
631+ acc = tl .permute (acc , (0 , 2 , 1 ))
632+ acc0 , acc1 = tl .split (acc )
633+ if BROADCAST_Y :
634+ y0 = y_desc .load ([0 , offs_wn ])
635+ else :
636+ y0 = y_desc .load ([offs_xm , offs_wn ])
637+ z0 = (acc0 + y0 .to (tl .float32 )).to (z_desc .dtype )
638+ z_desc .store ([offs_xm , offs_wn ], z0 )
639+ if BROADCAST_Y :
640+ y1 = y_desc .load ([0 , offs_wn + BLOCK_N // 2 ])
641+ else :
642+ y1 = y_desc .load ([offs_xm , offs_wn + BLOCK_N // 2 ])
643+ z1 = (acc1 + y1 .to (tl .float32 )).to (z_desc .dtype )
644+ z_desc .store ([offs_xm , offs_wn + BLOCK_N // 2 ], z1 )
645+ elif EPILOGUE_SUBTILE == 4 :
646+ acc = tl .reshape (accumulator , (BLOCK_M , 2 , BLOCK_N // 2 ))
647+ acc = tl .permute (acc , (0 , 2 , 1 ))
648+ acc_lo , acc_hi = tl .split (acc )
649+ acc_lo = tl .reshape (acc_lo , (BLOCK_M , 2 , BLOCK_N // 4 ))
650+ acc_lo = tl .permute (acc_lo , (0 , 2 , 1 ))
651+ acc0 , acc1 = tl .split (acc_lo )
652+ acc_hi = tl .reshape (acc_hi , (BLOCK_M , 2 , BLOCK_N // 4 ))
653+ acc_hi = tl .permute (acc_hi , (0 , 2 , 1 ))
654+ acc2 , acc3 = tl .split (acc_hi )
655+ if BROADCAST_Y :
656+ y0 = y_desc .load ([0 , offs_wn ])
657+ else :
658+ y0 = y_desc .load ([offs_xm , offs_wn ])
659+ z0 = (acc0 + y0 .to (tl .float32 )).to (z_desc .dtype )
660+ z_desc .store ([offs_xm , offs_wn ], z0 )
661+ if BROADCAST_Y :
662+ y1 = y_desc .load ([0 , offs_wn + BLOCK_N // 4 ])
663+ else :
664+ y1 = y_desc .load ([offs_xm , offs_wn + BLOCK_N // 4 ])
665+ z1 = (acc1 + y1 .to (tl .float32 )).to (z_desc .dtype )
666+ z_desc .store ([offs_xm , offs_wn + BLOCK_N // 4 ], z1 )
667+ if BROADCAST_Y :
668+ y2 = y_desc .load ([0 , offs_wn + 2 * BLOCK_N // 4 ])
669+ else :
670+ y2 = y_desc .load ([offs_xm , offs_wn + 2 * BLOCK_N // 4 ])
671+ z2 = (acc2 + y2 .to (tl .float32 )).to (z_desc .dtype )
672+ z_desc .store ([offs_xm , offs_wn + 2 * BLOCK_N // 4 ], z2 )
673+ if BROADCAST_Y :
674+ y3 = y_desc .load ([0 , offs_wn + 3 * BLOCK_N // 4 ])
675+ else :
676+ y3 = y_desc .load ([offs_xm , offs_wn + 3 * BLOCK_N // 4 ])
677+ z3 = (acc3 + y3 .to (tl .float32 )).to (z_desc .dtype )
678+ z_desc .store ([offs_xm , offs_wn + 3 * BLOCK_N // 4 ], z3 )
679+ else :
680+ tl .static_assert (False , "Unsupported EPILOGUE_SUBTILE value" )
681+
682+
532683@triton_autotune (
533- configs = get_mm_configs (pre_hook = _addmm_tma_set_block_size_hook ),
534- key = ["N" , "K" , "WARP_SPECIALIZE" ],
684+ configs = get_triton_persistent_configs (pre_hook = _addmm_tma_set_block_size_hook ),
685+ key = ["M" , "N" , "K" , "WARP_SPECIALIZE" ],
686+ prune_configs_by = {"early_config_prune" : _prune_persistent_autows_configs },
535687)
536688@triton .jit
537689def _addmm_fwd_tma_persistent (
@@ -550,6 +702,9 @@ def _addmm_fwd_tma_persistent(
550702 BROADCAST_Y : tl .constexpr ,
551703 WARP_SPECIALIZE : tl .constexpr ,
552704 NUM_SMS : tl .constexpr ,
705+ EPILOGUE_SUBTILE : tl .constexpr ,
706+ DATA_PARTITION_FACTOR : tl .constexpr ,
707+ USE_META_WS : tl .constexpr ,
553708):
554709 start_pid = tl .program_id (axis = 0 )
555710 num_pid_m = tl .cdiv (M , BLOCK_M )
@@ -559,27 +714,61 @@ def _addmm_fwd_tma_persistent(
559714
560715 num_pid_in_group = GROUP_M * num_pid_n
561716
562- for tile_id in tl .range (
563- start_pid , num_tiles , NUM_SMS , flatten = True , warp_specialize = WARP_SPECIALIZE
564- ):
565- pid_m , pid_n = _compute_pid (
566- tile_id , num_pid_in_group , num_pid_m , GROUP_M , NUM_SMS
567- )
568- offs_xm = pid_m * BLOCK_M
569- offs_wn = pid_n * BLOCK_N
570-
571- accumulator = tl .zeros ((BLOCK_M , BLOCK_N ), dtype = tl .float32 )
572- for k in tl .range (0 , k_tiles , warp_specialize = WARP_SPECIALIZE ):
573- offs_k = k * BLOCK_K
574- x = x_desc .load ([offs_xm , offs_k ])
575- w = w_desc .load ([offs_k , offs_wn ])
576- accumulator = tl .dot (x , w , accumulator , allow_tf32 = ALLOW_TF32 )
577- if BROADCAST_Y :
578- y = y_desc .load ([0 , offs_wn ])
579- else :
580- y = y_desc .load ([offs_xm , offs_wn ])
581- z = (accumulator + y .to (tl .float32 )).to (z_desc .dtype )
582- z_desc .store ([offs_xm , offs_wn ], z )
717+ if USE_META_WS :
718+ # Some arguments are only available in FBexperimental.
719+ # pyre-ignore[28]: smem_alloc_algo is FBexperimental
720+ for tile_id in tl .range (
721+ start_pid ,
722+ num_tiles ,
723+ NUM_SMS ,
724+ flatten = False ,
725+ warp_specialize = WARP_SPECIALIZE ,
726+ data_partition_factor = DATA_PARTITION_FACTOR ,
727+ smem_alloc_algo = 1 ,
728+ ):
729+ _addmm_persistent_tile_body (
730+ x_desc ,
731+ w_desc ,
732+ y_desc ,
733+ z_desc ,
734+ tile_id ,
735+ num_pid_in_group ,
736+ num_pid_m ,
737+ k_tiles ,
738+ BLOCK_M = BLOCK_M ,
739+ BLOCK_N = BLOCK_N ,
740+ BLOCK_K = BLOCK_K ,
741+ GROUP_M = GROUP_M ,
742+ ALLOW_TF32 = ALLOW_TF32 ,
743+ BROADCAST_Y = BROADCAST_Y ,
744+ NUM_SMS = NUM_SMS ,
745+ EPILOGUE_SUBTILE = EPILOGUE_SUBTILE ,
746+ INNER_WARP_SPECIALIZE = tl .constexpr (False ),
747+ )
748+ else :
749+ # Pure OAI Triton version.
750+ for tile_id in tl .range (
751+ start_pid , num_tiles , NUM_SMS , flatten = True , warp_specialize = WARP_SPECIALIZE
752+ ):
753+ _addmm_persistent_tile_body (
754+ x_desc ,
755+ w_desc ,
756+ y_desc ,
757+ z_desc ,
758+ tile_id ,
759+ num_pid_in_group ,
760+ num_pid_m ,
761+ k_tiles ,
762+ BLOCK_M = BLOCK_M ,
763+ BLOCK_N = BLOCK_N ,
764+ BLOCK_K = BLOCK_K ,
765+ GROUP_M = GROUP_M ,
766+ ALLOW_TF32 = ALLOW_TF32 ,
767+ BROADCAST_Y = BROADCAST_Y ,
768+ NUM_SMS = NUM_SMS ,
769+ EPILOGUE_SUBTILE = EPILOGUE_SUBTILE ,
770+ INNER_WARP_SPECIALIZE = WARP_SPECIALIZE ,
771+ )
583772
584773
585774@triton_autotune (
@@ -894,7 +1083,7 @@ def _addmm_fwd_tma_ws_persistent(
8941083 z = (result + y .to (tl .float32 )).to (z_desc .dtype )
8951084 z_buf_view = tlx .local_view (z_buffers , z_idx )
8961085 # If Y and Z are not shared wait for Z to be empty.
897- # If there are shared this already guarenteed .
1086+ # If there are shared this already guaranteed .
8981087 if not Y_Z_SHARED :
8991088 z_empty = tlx .local_view (z_empty_bars , z_idx )
9001089 tlx .barrier_wait (z_empty , z_load_phase ^ 1 )
@@ -1228,8 +1417,12 @@ def triton_addmm_fwd_tma_persistent(
12281417 x : torch .Tensor ,
12291418 w : torch .Tensor ,
12301419 y : torch .Tensor ,
1231- warp_specialize : bool = False ,
1420+ warp_specialize : bool | None = None ,
12321421) -> torch .Tensor :
1422+ _meta_ws = _use_meta_ws ()
1423+ if warp_specialize is None :
1424+ warp_specialize = _meta_ws
1425+
12331426 M , K = x .shape
12341427 _ , N = w .shape
12351428
@@ -1258,7 +1451,6 @@ def triton_addmm_fwd_tma_persistent(
12581451 NUM_SMS = torch .cuda .get_device_properties ("cuda" ).multi_processor_count
12591452
12601453 def grid (meta ):
1261- nonlocal x_desc , w_desc , z_desc
12621454 BLOCK_M = meta ["BLOCK_M" ]
12631455 BLOCK_N = meta ["BLOCK_N" ]
12641456 return (
@@ -1280,6 +1472,7 @@ def grid(meta):
12801472 BROADCAST_Y = is_y_1d ,
12811473 WARP_SPECIALIZE = warp_specialize ,
12821474 NUM_SMS = NUM_SMS ,
1475+ USE_META_WS = _meta_ws ,
12831476 )
12841477 return z
12851478
0 commit comments