Skip to content

Commit ca89f02

Browse files
committed
add tests
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
1 parent 0f1f226 commit ca89f02

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/flash_attn/test_flash_attn_varlen_func.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
1212

1313
NUM_HEADS = [(8, 2)]
14-
HEAD_SIZES = [64, 128, 192, 256]
14+
HEAD_SIZES = [64, 128, 192, 256, 512]
1515
BLOCK_SIZES = [64, 128]
1616
DTYPES = [torch.bfloat16]
1717
QDTYPES = [None]
@@ -365,6 +365,8 @@ def test_decode_with_paged_kv(
365365
# if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
366366
# pytest.skip("Flash attention with quantized inputs is only "
367367
# "supported on version 3 with bfloat16 base type")
368+
if head_size == 512 and block_size == 128:
369+
pytest.skip("skip test cases that may run out of SLM.")
368370
if num_heads == (16, 1) and head_size == 256:
369371
pytest.skip("skip test cases that may run out of SLM.")
370372
if block_size == 128 and num_blocks == 32768 and head_size >= 192:

0 commit comments

Comments (0)