I noticed this issue when trying to use BatchPrefillWithRaggedKVCacheWrapper with DeepSeek V3 and a custom mask. I was able to reproduce it with tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask.
I modified test_batch_prefill_with_ragged_kv_cache_custom_mask to accept a v_head_dim parameter. With that change, the test fails for many batch sizes when v_head_dim is not the same as head_dim:
@pytest.mark.parametrize("batch_size", list(range(1, 13)))
@pytest.mark.parametrize("kv_len", [54])
@pytest.mark.parametrize("qo_len", [37])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4])
@pytest.mark.parametrize("head_dim", [192])
@pytest.mark.parametrize("v_head_dim", [128])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE"])
@pytest.mark.parametrize("logits_soft_cap", [0.0])
@pytest.mark.parametrize("return_lse", [True])
def test_batch_prefill_with_ragged_kv_cache_custom_mask(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    v_head_dim,
    pos_encoding_mode,
    logits_soft_cap,
    return_lse,
):
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-1] FAILED [ 8%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-2] FAILED [ 16%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-3] FAILED [ 25%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-4] FAILED [ 33%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-5] FAILED [ 41%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-6] FAILED [ 50%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-7] PASSED [ 58%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-8] FAILED [ 66%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-9] PASSED [ 75%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-10] FAILED [ 83%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-11] PASSED [ 91%]
tests/test_batch_prefill_kernels.py::test_batch_prefill_with_ragged_kv_cache_custom_mask[True-0.0-NONE-128-192-4-4-37-54-12] PASSED [100%]
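For anyone reading the IDs: pytest builds them starting from the decorator closest to the function, so the fields read return_lse, logits_soft_cap, pos_encoding_mode, v_head_dim, head_dim, num_qo_heads, num_kv_heads, qo_len, kv_len, batch_size. A small sketch for pulling the batch size out of each line follows; the helper name batch_sizes_by_status is hypothetical, and it assumes batch_size is the last field of the ID as in the output above.

```python
import re

# Matches "[<param id>] PASSED" or "[<param id>] FAILED" on a pytest -v line.
# Assumption: the last dash-separated field of the ID is batch_size.
LINE_RE = re.compile(r"\[(?P<params>[^\]]+)\]\s+(?P<status>PASSED|FAILED)")


def batch_sizes_by_status(log):
    """Group batch sizes by test outcome (hypothetical helper)."""
    result = {"PASSED": [], "FAILED": []}
    for line in log.splitlines():
        match = LINE_RE.search(line)
        if match is None:
            continue  # not a test result line
        batch_size = int(match.group("params").rsplit("-", 1)[1])
        result[match.group("status")].append(batch_size)
    return result


# Two lines copied from the output above.
sample = (
    "tests/test_batch_prefill_kernels.py::"
    "test_batch_prefill_with_ragged_kv_cache_custom_mask"
    "[True-0.0-NONE-128-192-4-4-37-54-6] FAILED [ 50%]\n"
    "tests/test_batch_prefill_kernels.py::"
    "test_batch_prefill_with_ragged_kv_cache_custom_mask"
    "[True-0.0-NONE-128-192-4-4-37-54-7] PASSED [ 58%]\n"
)
summary = batch_sizes_by_status(sample)
```

Applied to the full output above, this groups batch sizes 1, 2, 3, 4, 5, 6, 8, and 10 as failing and 7, 9, 11, and 12 as passing.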
With head_dim == v_head_dim == 128, all tests pass.
With head_dim == v_head_dim == 192, all tests also pass.
The failure seems to depend on the combination of head_dim != v_head_dim and batch size. In practice, this means I cannot use a custom mask with a single request (batch size 1), which is one of the failing cases.
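For reference, here is a minimal NumPy sketch of what masked attention should compute for one request when the value head dimension differs from the query/key head dimension. It is not the reference used by the flashinfer test; the function name reference_attention, the natural-log LSE, and the requirement that every query row has at least one unmasked key are assumptions of this sketch, and it assumes num_qo_heads == num_kv_heads as in the parametrization above.

```python
import numpy as np


def reference_attention(q, k, v, mask):
    """Masked attention for a single request (illustrative sketch).

    q:    [qo_len, num_heads, head_dim]
    k:    [kv_len, num_heads, head_dim]
    v:    [kv_len, num_heads, v_head_dim]  (v_head_dim may differ from head_dim)
    mask: bool [qo_len, kv_len], True where a query may attend to a key.
          Assumption: every row contains at least one True.
    Returns out [qo_len, num_heads, v_head_dim] and lse [qo_len, num_heads].
    """
    head_dim = q.shape[-1]
    # Scores only involve q and k, so they are scaled by head_dim, not v_head_dim.
    scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(head_dim)
    scores = np.where(mask[None, :, :], scores, -np.inf)
    row_max = scores.max(axis=-1, keepdims=True)
    exp = np.exp(scores - row_max)
    denom = exp.sum(axis=-1, keepdims=True)
    probs = exp / denom
    # The output inherits its last dimension from v.
    out = np.einsum("hqk,khe->qhe", probs, v)
    # Natural-log LSE; the library's convention may differ.
    lse = (row_max + np.log(denom))[..., 0].T
    return out, lse


# Shapes taken from the failing parametrization above.
rng = np.random.default_rng(0)
qo_len, kv_len, num_heads, head_dim, v_head_dim = 37, 54, 4, 192, 128
q = rng.standard_normal((qo_len, num_heads, head_dim))
k = rng.standard_normal((kv_len, num_heads, head_dim))
v = rng.standard_normal((kv_len, num_heads, v_head_dim))
# Causal-style mask aligned to the end of the KV sequence.
mask = np.tril(np.ones((qo_len, kv_len), dtype=bool), k=kv_len - qo_len)
out, lse = reference_attention(q, k, v, mask)
```

Comparing the kernel output against a reference like this for batch size 1 with head_dim=192 and v_head_dim=128 could help narrow down whether the mismatch comes from the kernel or from the modified test.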