Skip to content

Commit 182fe48

Browse files
committed
make format happy
Signed-off-by: baodii <di.bao@intel.com>
1 parent ec179e7 commit 182fe48

1 file changed

Lines changed: 11 additions & 5 deletions

File tree

csrc/flash_attn/flash_api.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ inline int get_num_splits(
2323
// The decode kernel iterates kv_tile-sized work units within each page,
2424
// not page-sized units. The dispatch (see paged_decode_utils.hpp::
2525
// dispatch_by_page_size) routes
26-
// block_size == 16 -> kv_tile=_16 (SubgroupLayoutQK<_1,_1,_1>, SGPerWG=1)
27-
// block_size == 32 -> kv_tile=_32 (SubgroupLayoutQK<_1,_2,_1>, SGPerWG=2)
28-
// block_size > 0 && %% 64 == 0 -> kv_tile=_64 (SubgroupLayoutQK<_1,_4,_1>, SGPerWG=4)
26+
// block_size == 16 -> kv_tile=_16 (SubgroupLayoutQK<_1,_1,_1>, SGPerWG=1)
27+
// block_size == 32 -> kv_tile=_32 (SubgroupLayoutQK<_1,_2,_1>, SGPerWG=2)
28+
// block_size > 0 && block_size % 64 == 0 ->
29+
// kv_tile=_64 (SubgroupLayoutQK<_1,_4,_1>, SGPerWG=4)
2930
int kv_tile;
3031
int sg_per_wg;
3132
int policy_split_cap;
@@ -86,7 +87,8 @@ inline int get_num_splits(
8687
// (4) Each split must process at least ~4 KV tiles to amortize overhead.
8788
int max_splits_tiles = std::max(1, kv_tiles / 4);
8889
// (5) Hard cap of 32 (beyond this the ReduceSplitK kernel dominates).
89-
return std::max(1, std::min({splits, max_splits_tiles, 32, policy_split_cap}));
90+
return std::max(
91+
1, std::min({splits, max_splits_tiles, 32, policy_split_cap}));
9092
}
9193

9294
std::vector<at::Tensor> mha_varlen_fwd(
@@ -249,7 +251,11 @@ std::vector<at::Tensor> mha_varlen_fwd(
249251
}
250252

251253
int num_kv_splits = num_splits.value_or(get_num_splits(
252-
queue, batch_size, num_heads_q, num_heads_kv, effective_seqlen_k,
254+
queue,
255+
batch_size,
256+
num_heads_q,
257+
num_heads_kv,
258+
effective_seqlen_k,
253259
block_size));
254260

255261
at::Tensor tmp_out =

0 commit comments

Comments
 (0)