BatchPrefillWithRaggedKVCacheWrapper custom_mask doesn't work for small batch sizes when head_dim_qk != head_dim_vo #860

Open
@pankajroark

Description

I noticed this issue when trying to use BatchPrefillWithRaggedKVCacheWrapper with DeepSeek V3 and a custom mask. I was able to reproduce it with tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask.

I modified test_batch_prefill_with_ragged_kv_cache_custom_mask to accept a v_head_dim parameter. With that change, the test fails for many batch sizes when v_head_dim differs from head_dim:

@pytest.mark.parametrize("batch_size", list(range(1, 13)))
@pytest.mark.parametrize("kv_len", [54])
@pytest.mark.parametrize("qo_len", [37])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4])
@pytest.mark.parametrize("head_dim", [192])
@pytest.mark.parametrize("v_head_dim", [128])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE"])
@pytest.mark.parametrize("logits_soft_cap", [0.0])
@pytest.mark.parametrize("return_lse", [True])
def test_batch_prefill_with_ragged_kv_cache_custom_mask(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    v_head_dim,
    pos_encoding_mode,
    logits_soft_cap,
    return_lse,
):

Running it gives the following results (the last field of each test ID is batch_size):

tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-1] FAILED [  8%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-2] FAILED [ 16%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-3] FAILED [ 25%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-4] FAILED [ 33%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-5] FAILED [ 41%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-6] FAILED [ 50%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-7] PASSED [ 58%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-8] FAILED [ 66%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-9] PASSED [ 75%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-10] FAILED [ 83%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-11] PASSED [ 91%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-12] PASSED [100%]
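
To make the pass/fail pattern easier to read, here is a small sketch that decodes the test IDs above. pytest builds the ID from stacked parametrize decorators starting with the one closest to the function, so the fields appear in reverse order of the decorator list and batch_size comes last. The names PARAM_NAMES, parse_results, and failing_batch_sizes are made up for this illustration, and the result lines are shortened to the ID and status.

```python
import re

# Field order of the test ID: the decorator closest to the function comes first.
PARAM_NAMES = [
    "return_lse", "logits_soft_cap", "pos_encoding_mode", "v_head_dim",
    "head_dim", "num_qo_heads", "num_kv_heads", "qo_len", "kv_len", "batch_size",
]

# IDs and statuses copied from the output above (file and test name dropped).
RESULTS = """\
[True-0.0-NONE-128-192-4-4-37-54-1] FAILED
[True-0.0-NONE-128-192-4-4-37-54-2] FAILED
[True-0.0-NONE-128-192-4-4-37-54-3] FAILED
[True-0.0-NONE-128-192-4-4-37-54-4] FAILED
[True-0.0-NONE-128-192-4-4-37-54-5] FAILED
[True-0.0-NONE-128-192-4-4-37-54-6] FAILED
[True-0.0-NONE-128-192-4-4-37-54-7] PASSED
[True-0.0-NONE-128-192-4-4-37-54-8] FAILED
[True-0.0-NONE-128-192-4-4-37-54-9] PASSED
[True-0.0-NONE-128-192-4-4-37-54-10] FAILED
[True-0.0-NONE-128-192-4-4-37-54-11] PASSED
[True-0.0-NONE-128-192-4-4-37-54-12] PASSED
"""

LINE_RE = re.compile(r"\[(?P<id>[^\]]+)\]\s+(?P<status>PASSED|FAILED)")


def parse_results(text):
    """Return a list of (batch_size, status) pairs, one per result line."""
    rows = []
    for match in LINE_RE.finditer(text):
        params = dict(zip(PARAM_NAMES, match.group("id").split("-")))
        rows.append((int(params["batch_size"]), match.group("status")))
    return rows


def failing_batch_sizes(text):
    """Return the batch sizes whose test case failed."""
    return [batch for batch, status in parse_results(text) if status == "FAILED"]


print("failing:", failing_batch_sizes(RESULTS))
```

Read this way, the output shows failures for batch sizes 1 through 6, 8, and 10, and passes for 7, 9, 11, and 12.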

With head_dim == v_head_dim == 128, all tests pass.
With head_dim == v_head_dim == 192, all tests pass as well.

The failure seems to depend on the combination of head_dim != v_head_dim and batch size: batch sizes 1 through 6, 8, and 10 fail, while 7, 9, 11, and 12 pass. In practice, this means I can't use a custom mask with a single request (batch size 1).
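
For reference, here is a minimal NumPy sketch of what masked attention should compute for a single request when the query/key head dimension (192) differs from the value head dimension (128). It is not the FlashInfer implementation or the test's own reference code. The function name reference_attention, the causal-style mask, and the natural-log LSE are assumptions for illustration (the library's LSE convention may differ), and it assumes num_qo_heads == num_kv_heads and that every query row has at least one unmasked key.

```python
import numpy as np


def reference_attention(q, k, v, mask):
    """Masked attention for one request (hypothetical helper).

    q:    (qo_len, num_heads, head_dim)
    k:    (kv_len, num_heads, head_dim)
    v:    (kv_len, num_heads, v_head_dim)
    mask: (qo_len, kv_len) bool, True where a query may attend to a key
    Returns out (qo_len, num_heads, v_head_dim) and lse (qo_len, num_heads).
    """
    head_dim = q.shape[-1]
    # Scores only depend on the shared q/k head dimension.
    scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(head_dim)
    scores = np.where(mask[None, :, :], scores, -np.inf)
    # Numerically stable softmax over the key axis.
    row_max = scores.max(axis=-1, keepdims=True)
    exp = np.exp(scores - row_max)
    denom = exp.sum(axis=-1, keepdims=True)
    probs = exp / denom
    # The output takes its last dimension from v, not from q/k.
    out = np.einsum("hqk,khe->qhe", probs, v)
    lse = (row_max + np.log(denom))[..., 0].T
    return out, lse


# Shapes from the failing configuration above.
rng = np.random.default_rng(0)
qo_len, kv_len, num_heads, head_dim, v_head_dim = 37, 54, 4, 192, 128
q = rng.standard_normal((qo_len, num_heads, head_dim)).astype(np.float32)
k = rng.standard_normal((kv_len, num_heads, head_dim)).astype(np.float32)
v = rng.standard_normal((kv_len, num_heads, v_head_dim)).astype(np.float32)
# Causal-style mask aligned to the end of the KV sequence (an assumption).
mask = np.tril(np.ones((qo_len, kv_len), dtype=bool), k=kv_len - qo_len)

out, lse = reference_attention(q, k, v, mask)
print(out.shape, lse.shape)
```

Nothing in the math ties the result to batch size or requires the two head dimensions to match, which is consistent with the failures being an issue in the kernel path rather than in the configuration itself.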
