File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -173,6 +173,8 @@ def _attn_fwd_inner_oss_dp(
173173 hi ,
174174 BLOCK_N ,
175175 warp_specialize = warp_specialize ,
176+ merge_epilogue = True ,
177+ separate_epilogue_store = True ,
176178 # disallow_acc_multi_buffer=True,
177179 ):
178180 start_n = tl .multiple_of (start_n , BLOCK_N )
@@ -648,6 +650,8 @@ def _attn_fwd_persist(
648650 0 ,
649651 tiles_per_sm ,
650652 warp_specialize = warp_specialize and OUTER_LOOP ,
653+ merge_epilogue = True ,
654+ separate_epilogue_store = True ,
651655 data_partition_factor = DP_FACTOR ,
652656 ):
653657 pid = tile_idx % n_tile_num
@@ -905,7 +909,7 @@ def _attn_bwd_dkdv(
905909 0 ,
906910 num_steps ,
907911 warp_specialize = True ,
908- merge_epilogue = True ,
912+ merge_epilogue_to_computation = True ,
909913 tmem_alloc_algo = 2 ,
910914 smem_alloc_algo = 1 ,
911915 smem_budget = 200000 ,
@@ -1307,7 +1311,7 @@ def _attn_bwd_persist(
13071311 0 ,
13081312 tiles_per_sm ,
13091313 warp_specialize = True ,
1310- merge_epilogue = True ,
1314+ merge_epilogue_to_computation = True ,
13111315 tmem_alloc_algo = 2 ,
13121316 smem_alloc_algo = 1 ,
13131317 smem_budget = 200000 ,
Original file line number Diff line number Diff line change @@ -537,6 +537,7 @@ def matmul_kernel_tma_persistent(
537537 warp_specialize = WARP_SPECIALIZE ,
538538 data_partition_factor = DATA_PARTITION_FACTOR ,
539539 smem_alloc_algo = 1 ,
540+ separate_epilogue_store = True ,
540541 ):
541542 tile_id_c = _matmul_tma_persistent_loop_body (
542543 tile_id ,
You can’t perform that action at this time.
0 commit comments