feat: Support Hybrid Attention Models with extract_hidden_states (#157)
Draft
Draft · rahul-tuli wants to merge 4 commits into main
Conversation
Add CacheOnlySpec(KVCacheSpec) as a dedicated spec for CacheOnlyAttentionLayer, used by the extract_hidden_states speculative decoding method. Previously MLAAttentionSpec was used as a proxy, which caused two bugs:

- As an AttentionSpec subclass it entered unify_hybrid_kv_cache_specs, causing failures on hybrid models (Qwen3-MoE, Gemma4) where attention types cannot be unified.
- The CacheOnly group was incorrectly excluded from block-table and slot-mapping construction (handled in a follow-up commit).

CacheOnlySpec carries no factor-of-2 key/value split:

    page_size_bytes = block_size * num_hidden_states * hidden_size * dtype_bytes

Changes:

- vllm/v1/kv_cache_interface.py: add CacheOnlySpec; guard UniformTypeKVCacheSpecs.is_uniform_type against it
- vllm/v1/core/kv_cache_utils.py: pre-extract CacheOnly specs before routing in get_kv_cache_groups; deduct CacheOnly memory from the shared budget in get_kv_cache_config_from_groups and _max_memory_usage_bytes_from_groups
- vllm/v1/core/single_type_kv_cache_manager.py: add CacheOnlyManager (prefix cache disabled); register it in spec_manager_map
- vllm/model_executor/models/extract_hidden_states.py: return CacheOnlySpec from get_kv_cache_spec() instead of MLAAttentionSpec
- tests/v1/core/test_kv_cache_utils.py: 5 new tests covering the page-size formula, group routing, hybrid routing, and the memory budget

Co-authored-by: Claude
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
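The page-size formula above can be illustrated with a small dataclass. This is a hypothetical stand-in, not the actual vLLM class; the field names simply mirror the formula in the commit message:

```python
from dataclasses import dataclass


@dataclass
class CacheOnlySpecSketch:
    """Hypothetical stand-in for CacheOnlySpec: hidden-state storage
    with no key/value split, hence no factor of 2 in the page size."""
    block_size: int
    num_hidden_states: int
    hidden_size: int
    dtype_bytes: int

    @property
    def page_size_bytes(self) -> int:
        # One page holds block_size tokens, each storing
        # num_hidden_states vectors of hidden_size elements.
        return (self.block_size * self.num_hidden_states
                * self.hidden_size * self.dtype_bytes)


spec = CacheOnlySpecSketch(block_size=16, num_hidden_states=1,
                           hidden_size=4096, dtype_bytes=2)
print(spec.page_size_bytes)  # 16 * 1 * 4096 * 2 = 131072
```

An MLA-style spec would not match this layout, which is one reason a proxy spec was a poor fit.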
…ing, and reshape

CacheOnlySpec groups need a real block-table entry and correct slot mappings so hidden states are written to the right paged-block slots. Previously these groups were skipped alongside EncoderOnlyAttentionSpec, leaving input_batch.block_table without a CacheOnly entry (KeyError).

- _get_slot_mappings: split the EncoderOnly/CacheOnly branches; for CacheOnly, use the real block table and fill padding with 0 (not -1), since basic_cache uses plain tensor indexing and cannot handle negative slot IDs
- _reinitialize_input_batch: include CacheOnly in block_sizes so input_batch.block_table gains an entry; decouple kernel_block_sizes (which skips CacheOnly) using iter()/next()
- _reshape_kv_cache_tensors: reshape CacheOnly tensors to [num_blocks, block_size, num_hidden_states, hidden_size]
- prepare_kernel_block_sizes (worker/utils.py): skip CacheOnlySpec, which has no attention kernel and needs no kernel block-size entry

Co-authored-by: Claude
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
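The slot-mapping rule described above (real block table, 0-padding instead of -1) can be sketched in plain Python. This is an illustrative helper, not the actual _get_slot_mappings implementation:

```python
def cache_only_slot_mapping(num_tokens, block_table_row, block_size, padded_len):
    """Hypothetical sketch of the CacheOnly branch: map each token
    position to a physical slot through the real block table, and pad
    the tail with 0 rather than -1, because plain tensor indexing into
    the cache cannot handle negative slot IDs."""
    slots = []
    for pos in range(num_tokens):
        block_id = block_table_row[pos // block_size]
        slots.append(block_id * block_size + pos % block_size)
    # 0 is a valid (harmlessly overwritable) slot; -1 would be out of range
    return slots + [0] * (padded_len - num_tokens)


print(cache_only_slot_mapping(5, [3, 7], block_size=4, padded_len=8))
# [12, 13, 14, 15, 28, 0, 0, 0]
```

With blocks 3 and 7 assigned, token 4 lands in slot 7 * 4 + 0 = 28, and the three padding entries are 0.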
propose() was passing common_attn_metadata.slot_mapping (the main attention group's slot mapping, which uses -1 padding and different block IDs) to _get_slot_mapping. This caused hidden states to be written to incorrect or out-of-bounds slots in the CacheOnly KV cache.

Extract the slot mapping for the CacheOnly layer from the per-layer slot_mappings dict (slot_mappings[attn_layer_names[0]]) so the proposer writes hidden states to the same block slots that the KV connector reads from in extract_from_kv_cache.

Co-authored-by: Claude
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
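The fix amounts to selecting the CacheOnly layer's own entry from the per-layer dict instead of reusing the main group's mapping. A toy illustration (layer names and values are placeholders, not real vLLM identifiers):

```python
def pick_slot_mapping(slot_mappings, attn_layer_names):
    """Sketch: the proposer must use the CacheOnly layer's own slot
    mapping (0-padded, CacheOnly block IDs), not the main attention
    group's (-1-padded, different block IDs)."""
    return slot_mappings[attn_layer_names[0]]


slot_mappings = {
    "model.cache_only_layer": [12, 13, 14, 0],  # CacheOnly group
    "model.layers.0.attn": [40, 41, 42, -1],    # main attention group
}
print(pick_slot_mapping(slot_mappings, ["model.cache_only_layer"]))
# [12, 13, 14, 0]
```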
… CacheOnly

Without SupportsHMA, any kv_transfer_config triggers a blanket HMA disable in VllmConfig._post_init. With CacheOnlySpec now a separate group, disabling HMA would cause unify_hybrid_kv_cache_specs to encounter CacheOnlySpec alongside attention specs and fail. SupportsHMA opts the connector out of that disable path.

Additional connector fixes required once CacheOnly is its own group:

- _cache_group_idx: detect the CacheOnly group index at __init__ so all block_ids accesses use the correct group (it was hardcoded to [0])
- request_finished_all_groups: delegate to request_finished using block_ids[_cache_group_idx] (required by the SupportsHMA contract)
- build_connector_meta / request_finished: use _cache_group_idx instead of the literal 0 when reading block_ids

Co-authored-by: Claude
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
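The group-index handling can be sketched as a toy connector. This models only the indexing change; the real connector methods take richer metadata, and "cache_only" strings stand in for spec objects:

```python
class HiddenStatesConnectorSketch:
    """Hypothetical sketch: detect the CacheOnly group once at init and
    route every block_ids access through it, instead of hardcoding 0."""

    def __init__(self, group_kinds):
        # group_kinds stands in for the KV cache groups' spec types
        self._cache_group_idx = group_kinds.index("cache_only")

    def cache_only_blocks(self, block_ids_per_group):
        # was: block_ids_per_group[0] -- wrong once CacheOnly is its own group
        return block_ids_per_group[self._cache_group_idx]


conn = HiddenStatesConnectorSketch(
    ["full_attention", "sliding_window", "cache_only"])
print(conn.cache_only_blocks([[1, 2], [3, 4], [9, 10]]))  # [9, 10]
```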
rahul-tuli (Author): Tracked in JIRA: https://issues.redhat.com/browse/INFERENG-5961
fynnsu reviewed on Apr 8, 2026
            for i, g in enumerate(groups)
            if isinstance(g.kv_cache_spec, CacheOnlySpec)
        ),
        0,
We have 0 as a fallback here, but if a CacheOnlySpec isn't found, doesn't that mean the model isn't set up correctly? Perhaps we should fail here instead.
rahul-tuli (Author): Good catch! You're right: 0 is wrong here; raising an error is the right thing to do.
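Following the review, the fallback could be replaced by a raising lookup. A sketch under the same modeling as above (groups represented by spec-kind strings, not real spec objects):

```python
def cache_only_group_idx(group_kinds):
    """Sketch: find the CacheOnly group, raising instead of silently
    defaulting to 0 when it is absent, since a missing group means the
    model is not set up for extract_hidden_states."""
    idx = next(
        (i for i, kind in enumerate(group_kinds) if kind == "cache_only"),
        None,
    )
    if idx is None:
        raise ValueError("no CacheOnlySpec group found")
    return idx


print(cache_only_group_idx(["full_attention", "cache_only"]))  # 1
```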
            for spec in kv_cache_specs.values()
        )
    elif isinstance(one_spec, CacheOnlySpec):
        return False
Is this always called with all layers (in which case it would be safe to assume there are at least two specs)?
If not (if this is ever called with a subset of the layers, like just the draft layers), then we shouldn't assume this and should instead check the layers.
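The guard under discussion can be modeled in miniature. This toy version represents specs as kind strings rather than real spec objects, and works for any subset of layers:

```python
def is_uniform_type(spec_kinds):
    """Sketch of the is_uniform_type guard: a CacheOnly spec never forms
    a uniform-type group with attention specs, so its presence makes the
    check return False outright, regardless of which layers are passed."""
    if any(kind == "cache_only" for kind in spec_kinds):
        return False
    return len(set(spec_kinds)) == 1


print(is_uniform_type(["full_attention", "full_attention"]))  # True
print(is_uniform_type(["full_attention", "cache_only"]))      # False
print(is_uniform_type(["cache_only"]))                        # False
```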
This PR adds support for extracting hidden states with Hybrid Attention Models (for example, Qwen/Qwen3.5-9B).

The core issue was that `CacheOnlyAttentionLayer` (the single-layer "draft model" used by `extract_hidden_states`) was previously registered with `MLAAttentionSpec`. This caused two compounding bugs:

1. HMA was unconditionally disabled when any `kv_transfer_config` was set. With `MLAAttentionSpec`, the CacheOnly layer appeared as a regular attention layer, causing `unify_hybrid_kv_cache_specs` to fail on hybrid models (Qwen3.5, Gemma4) where attention types cannot be unified.
2. Even when HMA was turned on, unification with `MLAAttentionSpec` was not guaranteed to succeed (it failed for Gemma).

A dedicated `CacheOnlySpec(KVCacheSpec)` is introduced with semantics that match hidden-state storage:
Because `CacheOnlySpec` does not inherit from `AttentionSpec`, all existing `isinstance(spec, AttentionSpec)` guards naturally exclude it, and it is routed through a dedicated code path in each subsystem.
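The routing consequence follows directly from the class hierarchy. A minimal sketch of the relationship only (the real classes carry fields and methods):

```python
class KVCacheSpec:
    """Base for all KV cache specs (sketch)."""


class AttentionSpec(KVCacheSpec):
    """Specs with key/value attention semantics (sketch)."""


class CacheOnlySpec(KVCacheSpec):
    """Hidden-state storage: a sibling of AttentionSpec, not a subclass."""


spec = CacheOnlySpec()
print(isinstance(spec, AttentionSpec))  # False: attention-only guards skip it
print(isinstance(spec, KVCacheSpec))    # True: still a KV cache spec
```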
Memory budget
`CacheOnlySpec` groups are pre-extracted from the main spec map before routing, and `get_kv_cache_config_from_groups` deducts their memory cost from the shared pool. Both the main KV cache tensors and the CacheOnly tensor are allocated for the same `num_blocks`, ensuring that any block ID valid for the main cache is also valid for the CacheOnly cache. The CacheOnly tensor is given its own independent `KVCacheTensor` entry (not shared with attention layers).
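The shared-budget arithmetic can be sketched as follows. The numbers are illustrative, and this is not the actual computation in `get_kv_cache_config_from_groups`:

```python
def max_shared_num_blocks(total_bytes, attn_page_bytes_per_group,
                          cache_only_page_bytes):
    """Sketch: every block costs one page in each attention group plus
    one CacheOnly page, so both caches can share the same num_blocks."""
    per_block = sum(attn_page_bytes_per_group) + cache_only_page_bytes
    return total_bytes // per_block


# one attention group (2 * 16 * 4096 * 2 bytes per page) plus a
# CacheOnly group (16 * 1 * 4096 * 2 bytes per page), 1 GiB budget
n = max_shared_num_blocks(1 << 30, [262144], 131072)
print(n)  # 2730
```

Deducting the CacheOnly cost before sizing the attention caches keeps every block ID valid in both caches, at the price of slightly fewer total blocks.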
HMA interaction
`ExampleHiddenStatesConnector` now declares `SupportsHMA`. The `VllmConfig` post-init check uses this to skip the blanket HMA-disable that would otherwise fire whenever a `kv_transfer_config` is present. Keeping HMA enabled is required for hybrid models (sliding-window attention, Mamba) to use their separate KV cache groups correctly.
High-level description of each change
- `vllm/v1/kv_cache_interface.py`: add `CacheOnlySpec`; guard `UniformTypeKVCacheSpecs.is_uniform_type` against it
- `vllm/v1/core/single_type_kv_cache_manager.py`: add `CacheOnlyManager` (prefix cache disabled); register it in `spec_manager_map`
- `vllm/v1/core/kv_cache_utils.py`: pre-extract CacheOnly specs before routing in `get_kv_cache_groups`; account for CacheOnly memory in `get_kv_cache_config_from_groups` and `_max_memory_usage_bytes_from_groups`
- `vllm/model_executor/models/extract_hidden_states.py`: return `CacheOnlySpec` from `get_kv_cache_spec()` instead of `MLAAttentionSpec`
- `vllm/v1/worker/gpu_model_runner.py`: handle CacheOnly in `_get_slot_mappings` (real block table, 0-padded); add CacheOnly to the block table in `_reinitialize_input_batch` via `iter(kernel_block_sizes)` decoupling
- `vllm/v1/spec_decode/extract_hidden_states.py`: take `cache_only_slot_mapping` from the full `slot_mappings` dict in `propose()`
- `vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py`: declare `SupportsHMA`; detect the CacheOnly group index dynamically via `_cache_group_idx`; add `request_finished_all_groups`; fix `block_ids` indexing to use `_cache_group_idx`
- `vllm/config/vllm.py`: add a `supports_hma()` check so connectors implementing `SupportsHMA` keep HMA enabled
- `tests/v1/core/test_kv_cache_utils.py`: test `CacheOnlySpec` page size, group routing, and memory budget

Test plan
Unit tests
    .venv/bin/python -m pytest tests/v1/core/test_kv_cache_utils.py \
        -k "cache_only" -v

Manual end-to-end verification
Manually verified with `examples/offline_inference/extract_hidden_states.py` on both a non-hybrid and a hybrid model:
Hybrid (non-uniform attention):
Non-hybrid (uniform full attention):
Both runs produced correctly shaped `hidden_states` tensors saved to safetensors, with `token_ids` matching the prompt token IDs.

Checklist
isinstance(spec, CacheOnlySpec))