Skip to content

Commit ec05cc8

Browse files
authored
update FA/gemm for autoWS annotation change (#1021)
1 parent 93f4062 commit ec05cc8

2 files changed

Lines changed: 7 additions & 2 deletions

File tree

tritonbench/kernels/blackwell_triton_fused_attention.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ def _attn_fwd_inner_oss_dp(
173173
hi,
174174
BLOCK_N,
175175
warp_specialize=warp_specialize,
176+
merge_epilogue=True,
177+
separate_epilogue_store=True,
176178
# disallow_acc_multi_buffer=True,
177179
):
178180
start_n = tl.multiple_of(start_n, BLOCK_N)
@@ -648,6 +650,8 @@ def _attn_fwd_persist(
648650
0,
649651
tiles_per_sm,
650652
warp_specialize=warp_specialize and OUTER_LOOP,
653+
merge_epilogue=True,
654+
separate_epilogue_store=True,
651655
data_partition_factor=DP_FACTOR,
652656
):
653657
pid = tile_idx % n_tile_num
@@ -905,7 +909,7 @@ def _attn_bwd_dkdv(
905909
0,
906910
num_steps,
907911
warp_specialize=True,
908-
merge_epilogue=True,
912+
merge_epilogue_to_computation=True,
909913
tmem_alloc_algo=2,
910914
smem_alloc_algo=1,
911915
smem_budget=200000,
@@ -1307,7 +1311,7 @@ def _attn_bwd_persist(
13071311
0,
13081312
tiles_per_sm,
13091313
warp_specialize=True,
1310-
merge_epilogue=True,
1314+
merge_epilogue_to_computation=True,
13111315
tmem_alloc_algo=2,
13121316
smem_alloc_algo=1,
13131317
smem_budget=200000,

tritonbench/operators/gemm/warp_spec_persistent_matmul.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def matmul_kernel_tma_persistent(
537537
warp_specialize=WARP_SPECIALIZE,
538538
data_partition_factor=DATA_PARTITION_FACTOR,
539539
smem_alloc_algo=1,
540+
separate_epilogue_store=True,
540541
):
541542
tile_id_c = _matmul_tma_persistent_loop_body(
542543
tile_id,

0 commit comments

Comments
 (0)