Skip to content

Commit

Permalink
bugfix: fix the behavior of MLA kernel when kv-length is 0 (#868)
Browse files Browse the repository at this point in the history
The scheduling algorithm in #863 does not consider that some requests may have
kv-cache length 0; this PR fixes the issue.
  • Loading branch information
yzh119 authored Feb 17, 2025
1 parent 7cd000b commit 6ec3bae
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 3 additions & 1 deletion include/flashinfer/attention/scheduler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,8 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
qo_indptr_h[i] * num_heads +
std::min((qo_tile_idx + 1) * cluster_tile_q, packed_qo_len);
}
while (remaining_len > 0) {
bool zero_kv_len = (remaining_len == 0);
while (remaining_len > 0 || zero_kv_len) {
auto [cluster_idx, accum_cost] = cluster_cost_heap.pop();
int actual_len = std::min(remaining_len, kv_len_limit);
cluster_cost_heap.insert(
Expand All @@ -1154,6 +1155,7 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
cluster_kv_end[cluster_idx].push_back(kv_start + actual_len);
remaining_len -= actual_len;
kv_start += actual_len;
if (zero_kv_len) break;
}
split_kv_count += int(split_kv);
}
Expand Down
5 changes: 3 additions & 2 deletions tests/test_deepseek_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads):


@pytest.mark.parametrize("batch_size", [1, 17, 37])
@pytest.mark.parametrize("kv_len", [17, 33, 96, 97, 114, 514, 1024])
@pytest.mark.parametrize("kv_len", [0, 17, 33, 96, 97, 114, 514, 1024])
@pytest.mark.parametrize("qo_len", [1, 17, 37, 77])
@pytest.mark.parametrize("num_heads", [4, 32, 128])
@pytest.mark.parametrize("causal", [False, True])
Expand Down Expand Up @@ -243,7 +243,8 @@ def test_batch_mla_page_attention(
o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
lse_ref = lse_ref.flatten(0, 1)
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
if kv_len != 0:
torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)

# test with pre-allocated output
o_buffer = torch.empty_like(o)
Expand Down

0 comments on commit 6ec3bae

Please sign in to comment.