@@ -23,9 +23,10 @@ inline int get_num_splits(
2323 // The decode kernel iterates kv_tile-sized work units within each page,
2424 // not page-sized units. The dispatch (see paged_decode_utils.hpp::
2525 // dispatch_by_page_size) routes
26- // block_size == 16 -> kv_tile=_16 (SubgroupLayoutQK<_1,_1,_1>, SGPerWG=1)
27- // block_size == 32 -> kv_tile=_32 (SubgroupLayoutQK<_1,_2,_1>, SGPerWG=2)
28- // block_size > 0 && %% 64 == 0 -> kv_tile=_64 (SubgroupLayoutQK<_1,_4,_1>, SGPerWG=4)
26+ //   block_size == 16 -> kv_tile=_16 (SubgroupLayoutQK<_1,_1,_1>, SGPerWG=1)
27+ //   block_size == 32 -> kv_tile=_32 (SubgroupLayoutQK<_1,_2,_1>, SGPerWG=2)
28+ //   block_size > 0 && block_size % 64 == 0 -> kv_tile=_64
29+ //                      (SubgroupLayoutQK<_1,_4,_1>, SGPerWG=4)
2930 int kv_tile;
3031 int sg_per_wg;
3132 int policy_split_cap;
@@ -86,7 +87,8 @@ inline int get_num_splits(
8687 // (4) Each split must process at least ~4 KV tiles to amortize overhead.
8788 int max_splits_tiles = std::max (1 , kv_tiles / 4 );
8889 // (5) Hard cap of 32 (beyond this the ReduceSplitK kernel dominates).
89- return std::max (1 , std::min ({splits, max_splits_tiles, 32 , policy_split_cap}));
90+ return std::max (
91+ 1 , std::min ({splits, max_splits_tiles, 32 , policy_split_cap}));
9092}
9193
9294std::vector<at::Tensor> mha_varlen_fwd (
@@ -249,7 +251,11 @@ std::vector<at::Tensor> mha_varlen_fwd(
249251 }
250252
251253 int num_kv_splits = num_splits.value_or (get_num_splits (
252- queue, batch_size, num_heads_q, num_heads_kv, effective_seqlen_k,
254+ queue,
255+ batch_size,
256+ num_heads_q,
257+ num_heads_kv,
258+ effective_seqlen_k,
253259 block_size));
254260
255261 at::Tensor tmp_out =
0 commit comments