2929import triton .language as tl
3030from generative_recommenders .ops .utils import is_sm100_plus , maybe_register_custom_op
3131
32+ try :
33+ # @manual=//triton:triton
34+ from triton .language .extra .subtile_ops import _split_n_2D
35+ except ImportError :
36+ _split_n_2D = None
37+
3238try :
3339 # @manual=//triton:triton
3440 import triton .language .extra .tlx as tlx # type: ignore
@@ -96,13 +102,20 @@ def _check_tma_alignment(
96102def _prune_persistent_autows_configs (configs , named_args , ** kwargs ): # noqa
97103 if not _use_meta_ws ():
98104 return configs
105+ BROADCAST_Y = kwargs .get ("BROADCAST_Y" , False )
99106 pruned = []
100107 for c in configs :
101108 BLOCK_M = c .kwargs .get ("BLOCK_M" , 0 )
109+ BLOCK_N = c .kwargs .get ("BLOCK_N" , 0 )
110+ EPILOGUE_SUBTILE = c .kwargs .get ("EPILOGUE_SUBTILE" , 1 )
102111 DP = c .kwargs .get ("DATA_PARTITION_FACTOR" , 1 )
103112 # DATA_PARTITION_FACTOR=2 is only supported with BLOCK_M=256
104113 if DP == 2 and BLOCK_M != 256 :
105114 continue
115+ if (BLOCK_N // EPILOGUE_SUBTILE ) < 32 :
116+ continue
117+ if BROADCAST_Y and (BLOCK_N // EPILOGUE_SUBTILE ) < 64 :
118+ continue
106119 pruned .append (c )
107120 return pruned
108121
@@ -120,7 +133,6 @@ def _prune_configs_for_tlx_persistent_addmm(configs, named_args, **kwargs): # n
120133 NUM_MMA_GROUPS = c .kwargs .get ("NUM_MMA_GROUPS" , 1 )
121134 BLOCK_M_SPLIT = BLOCK_M // NUM_MMA_GROUPS
122135 NUM_SMEM_BUFFERS = c .kwargs .get ("NUM_SMEM_BUFFERS" , 1 )
123- EPILOGUE_SUBTILE = c .kwargs .get ("EPILOGUE_SUBTILE" , 1 )
124136
125137 # Hardware constraint: Always make MMA tile 128.
126138 if BLOCK_M_SPLIT != 128 :
@@ -459,7 +471,7 @@ def get_triton_persistent_configs(pre_hook=None) -> List[triton.Config]:
459471 for block_n in [64 , 128 , 256 ]
460472 for block_k in [64 , 128 , 256 ]
461473 for num_stages in [2 , 3 , 4 ]
462- for subtile in [1 , 2 , 4 ]
474+ for subtile in [1 , 2 , 4 , 8 ]
463475 for DP in [1 , 2 ]
464476 ]
465477
@@ -619,65 +631,19 @@ def _addmm_persistent_tile_body(
619631
620632 # Epilogue subtiling breaks the store into multiple pieces to reduce
621633 # shared memory consumption and allow higher stage counts.
622- if EPILOGUE_SUBTILE == 1 :
623- if BROADCAST_Y :
624- y = y_desc .load ([0 , offs_wn ])
625- else :
626- y = y_desc .load ([offs_xm , offs_wn ])
627- z = (accumulator + y .to (tl .float32 )).to (z_desc .dtype )
628- z_desc .store ([offs_xm , offs_wn ], z )
629- elif EPILOGUE_SUBTILE == 2 :
630- acc = tl .reshape (accumulator , (BLOCK_M , 2 , BLOCK_N // 2 ))
631- acc = tl .permute (acc , (0 , 2 , 1 ))
632- acc0 , acc1 = tl .split (acc )
633- if BROADCAST_Y :
634- y0 = y_desc .load ([0 , offs_wn ])
635- else :
636- y0 = y_desc .load ([offs_xm , offs_wn ])
637- z0 = (acc0 + y0 .to (tl .float32 )).to (z_desc .dtype )
638- z_desc .store ([offs_xm , offs_wn ], z0 )
639- if BROADCAST_Y :
640- y1 = y_desc .load ([0 , offs_wn + BLOCK_N // 2 ])
641- else :
642- y1 = y_desc .load ([offs_xm , offs_wn + BLOCK_N // 2 ])
643- z1 = (acc1 + y1 .to (tl .float32 )).to (z_desc .dtype )
644- z_desc .store ([offs_xm , offs_wn + BLOCK_N // 2 ], z1 )
645- elif EPILOGUE_SUBTILE == 4 :
646- acc = tl .reshape (accumulator , (BLOCK_M , 2 , BLOCK_N // 2 ))
647- acc = tl .permute (acc , (0 , 2 , 1 ))
648- acc_lo , acc_hi = tl .split (acc )
649- acc_lo = tl .reshape (acc_lo , (BLOCK_M , 2 , BLOCK_N // 4 ))
650- acc_lo = tl .permute (acc_lo , (0 , 2 , 1 ))
651- acc0 , acc1 = tl .split (acc_lo )
652- acc_hi = tl .reshape (acc_hi , (BLOCK_M , 2 , BLOCK_N // 4 ))
653- acc_hi = tl .permute (acc_hi , (0 , 2 , 1 ))
654- acc2 , acc3 = tl .split (acc_hi )
655- if BROADCAST_Y :
656- y0 = y_desc .load ([0 , offs_wn ])
657- else :
658- y0 = y_desc .load ([offs_xm , offs_wn ])
659- z0 = (acc0 + y0 .to (tl .float32 )).to (z_desc .dtype )
660- z_desc .store ([offs_xm , offs_wn ], z0 )
661- if BROADCAST_Y :
662- y1 = y_desc .load ([0 , offs_wn + BLOCK_N // 4 ])
663- else :
664- y1 = y_desc .load ([offs_xm , offs_wn + BLOCK_N // 4 ])
665- z1 = (acc1 + y1 .to (tl .float32 )).to (z_desc .dtype )
666- z_desc .store ([offs_xm , offs_wn + BLOCK_N // 4 ], z1 )
667- if BROADCAST_Y :
668- y2 = y_desc .load ([0 , offs_wn + 2 * BLOCK_N // 4 ])
669- else :
670- y2 = y_desc .load ([offs_xm , offs_wn + 2 * BLOCK_N // 4 ])
671- z2 = (acc2 + y2 .to (tl .float32 )).to (z_desc .dtype )
672- z_desc .store ([offs_xm , offs_wn + 2 * BLOCK_N // 4 ], z2 )
634+ tl .static_assert (
635+ EPILOGUE_SUBTILE <= 8 ,
636+ "EPILOGUE_SUBTILE > 8 is not supported" ,
637+ )
638+ acc_subtiles = _split_n_2D (accumulator , EPILOGUE_SUBTILE ) # pyre-ignore[16]
639+ slice_size : tl .constexpr = BLOCK_N // EPILOGUE_SUBTILE
640+ for i in tl .static_range (EPILOGUE_SUBTILE ):
673641 if BROADCAST_Y :
674- y3 = y_desc .load ([0 , offs_wn + 3 * BLOCK_N // 4 ])
642+ y_i = y_desc .load ([0 , offs_wn + i * slice_size ])
675643 else :
676- y3 = y_desc .load ([offs_xm , offs_wn + 3 * BLOCK_N // 4 ])
677- z3 = (acc3 + y3 .to (tl .float32 )).to (z_desc .dtype )
678- z_desc .store ([offs_xm , offs_wn + 3 * BLOCK_N // 4 ], z3 )
679- else :
680- tl .static_assert (False , "Unsupported EPILOGUE_SUBTILE value" )
644+ y_i = y_desc .load ([offs_xm , offs_wn + i * slice_size ])
645+ z_i = (acc_subtiles [i ] + y_i .to (tl .float32 )).to (z_desc .dtype )
646+ z_desc .store ([offs_xm , offs_wn + i * slice_size ], z_i )
681647
682648
683649@triton_autotune (
0 commit comments