Checklist
- I searched related issues but found no solution.
- The bug persists in the latest version.
- Issues without environment info and a minimal reproducible demo are hard to resolve and may receive no feedback.
- If this is not a bug report but a general question, please start a discussion at https://github.com/sgl-project/sglang/discussions. Otherwise, it will be closed.
- Please use English. Otherwise, it will be closed.
Describe the bug
When enable_fused_set_kv_buffer is overridden to return False and test/registered/dllm/test_llada2_mini.py is run, the test fails.
The root cause is that the FlashInfer attention backend skips save_kv_cache, which is incorrect. After removing the following lines, the test passes again:
sglang/python/sglang/srt/layers/attention/flashinfer_backend.py, lines 815 to 816 in d73f06f:

```python
if save_kv_cache and layer.attn_type == AttentionType.ENCODER_ONLY:
    save_kv_cache = False
```
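In other words, the workaround is simply to delete those two lines. A sketch of the diff against d73f06f (indentation approximate):

```diff
-        if save_kv_cache and layer.attn_type == AttentionType.ENCODER_ONLY:
-            save_kv_cache = False
```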
Reproduction
- Modify this function so that it always returns False (a sketch of the change is shown after this list).
sglang/python/sglang/srt/models/utils.py, lines 107 to 114 in d73f06f:

```python
def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
    return (
        _is_cuda
        and hasattr(forward_batch.token_to_kv_pool, "dtype")
        and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
        and not isinstance(forward_batch.token_to_kv_pool, SWAKVPool)
    )
```

- Run:

```
pytest -xss test/registered/dllm/test_llada2_mini.py
```
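For concreteness, a minimal sketch of the step-1 modification (ForwardBatch is already imported by utils.py, since it appears in the original signature, so only the function body changes):

```python
# Sketch of the reproduction change to sglang/python/sglang/srt/models/utils.py:
# hard-code the helper to return False so the fused set_kv_buffer path is never
# taken and the attention backend has to write the KV cache itself.
def enable_fused_set_kv_buffer(forward_batch: ForwardBatch) -> bool:
    """Force-disable the fused set_kv_buffer path (for reproducing the bug only)."""
    return False
```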
Environment
H200