Skip to content

Commit 9499558

Browse files
authored
Merge branch 'Dao-AILab:main' into main
2 parents 56d9d23 + 4178915 commit 9499558

24 files changed

Lines changed: 2923 additions & 462 deletions

csrc/flash_attn/flash_api.cpp

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -698,11 +698,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
698698
params.page_block_size = page_block_size;
699699
// Keep references to these tensors to extend their lifetime
700700
at::Tensor softmax_lse_accum, out_accum;
701-
if (paged_KV || seqlenq_ngroups_swapped) {
701+
if (seqlenq_ngroups_swapped) {
702702
std::tie(softmax_lse_accum, out_accum) =
703703
set_params_splitkv(params, batch_size, num_heads, head_size,
704704
max_seqlen_k, max_seqlen_q, head_size_rounded,
705705
p_dropout, num_splits, get_num_sm(get_current_device()), opts);
706+
} else if (paged_KV) {
707+
TORCH_CHECK(num_splits <= 1, "num_splits > 1 is not supported for varlen paged KV");
708+
params.num_splits = num_splits;
706709
}
707710

708711
if (leftpad_k_.has_value()) {

flash_attn/cute/block_info.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -143,14 +143,14 @@ def get_n_block_max_for_m_block(
143143
self,
144144
seqlen_info: SeqlenInfoQK,
145145
m_block: Int32,
146-
n_block_global_max: Int32,
147146
) -> Int32:
147+
n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
148148
if const_expr(self.is_causal or self.window_size_right is not None):
149149
m_idx_max = (m_block + 1) * self.tile_m
150150
if const_expr(self.qhead_per_kvhead_packgqa > 1):
151151
m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)
152152
n_idx_right = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
153153
if const_expr(self.window_size_right is not None):
154154
n_idx_right += self.window_size_right
155-
return min(n_block_global_max, cute.ceil_div(n_idx_right, self.tile_n))
156-
return n_block_global_max
155+
n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.tile_n))
156+
return n_block_max

0 commit comments

Comments
 (0)