[Spec Decode] Support hybrid attention models in extract_hidden_states#39949
Open
mgoin wants to merge 1 commit into vllm-project:main from
Conversation
Contributor
Code Review
This pull request introduces the HiddenStateCacheSpec to support hidden-state extraction within the vLLM V1 engine. Key changes include updating the KV cache grouping heuristics to prevent singleton cache-only layers from collapsing group sizes, refactoring the ExampleHiddenStatesConnector to utilize attn_metadata.slot_mapping directly, and implementing dynamic HMA (Hybrid Memory Architecture) support checks for connectors. Feedback is provided regarding the max_memory_usage_bytes implementation in HiddenStateCacheSpec, which currently fails to account for context parallelism, potentially leading to memory over-estimation during initialization.
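The block/offset slot addressing that replaces the old flatten-based lookup can be illustrated with a toy sketch (NumPy here for simplicity; this is not vLLM code, just the indexing idea):

```python
import numpy as np

# Toy sketch of block/offset indexing: given a flat slot id from
# slot_mapping, the cache entry lives at kv[slot // block_size,
# slot % block_size].  This addresses slots in place even when `kv`
# is a non-contiguous strided view, where .flatten() would force a
# copy (or fail on torch views).
block_size = 4
kv = np.arange(3 * block_size).reshape(3, block_size)  # [num_blocks, block_size]
slot_mapping = np.array([0, 5, 10, 3])

blocks = slot_mapping // block_size   # which block each slot falls in
offsets = slot_mapping % block_size   # position within that block
gathered = kv[blocks, offsets]        # == [0, 5, 10, 3] since kv[i, j] = i*4 + j
```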
Hidden-state extraction breaks on hybrid-attention models (e.g. Qwen3.5) because `kv_transfer_config` force-disables HMA and `unify_hybrid_kv_cache_specs` cannot fold `MambaSpec` into a uniform type. Fix by gating the HMA disable on `supports_hma(connector_cls)`, making `ExampleHiddenStatesConnector` a `SupportsHMA` subclass, and handling the cache-only layer's page alignment for hybrid models.

Key changes:
- `HiddenStateCacheSpec`: thin marker subclass of `MLAAttentionSpec` (inherits all dispatch behavior, no overrides). Defined in `kv_cache_interface.py`, registered in `spec_manager_map`.
- `get_kv_cache_groups`: filter `HiddenStateCacheSpec` out before unify/grouping, add back as a 1-layer group with `page_size_padded` aligned to the common page. General sub-functions untouched.
- `gpu_model_runner`: `as_strided` reshape branch for padded specs (`page_size_padded` > real page), proposer `isinstance` check for `kv_cache_gid`.
- Connector: read `slot_mapping` from `attn_metadata` (not scheduler block IDs), remove dead `ReqMeta.slot_mapping` field.
- Proposer: `kv_cache_gid` for correct `common_attn_metadata` selection.
- `basic_cache`/`extract_from_kv_cache`: block/offset indexing instead of flatten (works on non-contiguous strided tensors).

Verified: Llama integration test + Qwen3.5-4B end-to-end on GPU.

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary
Hidden-state extraction (the `extract_hidden_states` speculative method) currently doesn't work on hybrid-attention models like Qwen3.5 (GatedDeltaNet + full attention). The failure chain:
- `kv_transfer_config` is set → HMA is unconditionally force-disabled in `VllmConfig.__post_init__`
- `unify_hybrid_kv_cache_specs` tries to fold all specs into one type → can't handle `MambaSpec` alongside attention specs → `ValueError`

This PR fixes the issue by letting connectors that declare `SupportsHMA` keep HMA enabled, and teaching the KV cache grouping to handle the cache-only hidden-state layer alongside hybrid attention groups.

Approach
- **Marker spec class** — `HiddenStateCacheSpec` is a thin subclass of `MLAAttentionSpec` with no behavioral overrides. It exists purely as a type tag so `get_kv_cache_groups` can identify cache-only layers. Because it inherits from `MLAAttentionSpec` → `FullAttentionSpec`, it passes through all existing `isinstance` checks, `is_uniform_type`, `spec_manager_map`, and `find_longest_cache_hit` without any changes to those paths.
- **Filter-before-group, add-back-after** — For hybrid models, the cache-only layer's page size (determined by `num_aux_hidden_states × hidden_size`) generally won't divide evenly into the Mamba-aligned common page. Rather than modifying the unification or grouping algorithms, `get_kv_cache_groups` filters `HiddenStateCacheSpec` layers out before calling `unify_kv_cache_spec_page_size` and `_get_kv_cache_groups_uniform_page_size` (both untouched), then adds them back as their own 1-layer groups with `block_size` shrunk and `page_size_padded` aligned to the common page.
- **Strided tensor reshape** — The page padding means the allocated tensor has gaps between blocks. `gpu_model_runner._reshape_kv_cache_tensors` gets an `as_strided` branch (guarded by `page_size_padded > real_page_size_bytes`) that sets the block-level stride to span the full padded page, matching the pattern `MambaSpec` already uses for its state tensors.
- **Connector fixes** — `ExampleHiddenStatesConnector` now inherits `SupportsHMA` and reads `slot_mapping` directly from `attn_metadata` instead of recomputing it from scheduler block IDs (which use the wrong block size under HMA). The now-dead `ReqMeta.slot_mapping` field and its per-request CPU tensor allocation are removed.
- **Proposer group selection** — `ExtractHiddenStatesProposer` records its `kv_cache_gid` in `validate_same_kv_cache_group` so the model runner selects the correct `common_attn_metadata` for the cache-only group, matching the existing `EagleProposer` pattern.
- **Block/offset cache ops** — `basic_cache` and `extract_from_kv_cache` use `slot_mapping // block_size` / `slot_mapping % block_size` indexing instead of `.view()`/`.flatten()`, which works on non-contiguous (strided) tensors.

Test plan
- `tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py` — Llama end-to-end (GPU)
- `extract_hidden_states` — hybrid model end-to-end (GPU), hidden states shape `[N, 3, 2560]` with non-zero values
- `pre-commit run ruff-check / ruff-format / mypy-3.10` — all passing

🤖 AI-assisted (Claude)
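The filter-before-group / add-back-after flow from the Approach section can be sketched roughly as below. `HiddenStateCacheSpec` follows the PR's naming, but `Spec`, `Group`, and the inline grouping step are simplified stand-ins for the real vLLM structures and helpers:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Spec:
    page_size_bytes: int

@dataclass
class HiddenStateCacheSpec(Spec):
    # Marker subclass: no overrides, exists purely as a type tag so the
    # grouping code can recognize cache-only hidden-state layers.
    pass

@dataclass
class Group:
    layers: list
    page_size_padded: Optional[int] = None

def get_kv_cache_groups(specs: dict, common_page: int) -> list:
    # 1) Filter cache-only layers out before the existing unify/grouping
    #    helpers run on the attention/Mamba layers (shown here as one
    #    trivial group; the real helpers are untouched by the PR).
    hidden = {n: s for n, s in specs.items() if isinstance(s, HiddenStateCacheSpec)}
    rest = {n: s for n, s in specs.items() if n not in hidden}
    groups = [Group(layers=sorted(rest))]
    # 2) Add each cache-only layer back as its own 1-layer group, with
    #    its page rounded up to a multiple of the common page.
    for name, spec in hidden.items():
        padded = -(-spec.page_size_bytes // common_page) * common_page  # ceil-div
        groups.append(Group(layers=[name], page_size_padded=padded))
    return groups
```

For example, a hidden-state page of 15360 bytes against a 4096-byte common page gets padded to 16384 bytes, leaving a gap at the end of each block that the strided reshape later skips.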
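The `as_strided` branch can be modeled in miniature with NumPy (torch's `as_strided` behaves analogously; the sizes here are arbitrary toy values): the allocator hands back a flat buffer whose blocks sit a padded page apart, and a view whose block stride spans the padded page exposes only the real pages, without copying.

```python
import numpy as np

# Toy model of the padded-page reshape: only the first `real_page`
# elements of each `padded_page`-sized block hold data; the view's
# block-level stride spans the full padded page so the gaps are skipped.
num_blocks, real_page, padded_page = 4, 6, 8
buf = np.arange(num_blocks * padded_page, dtype=np.float32)

view = np.lib.stride_tricks.as_strided(
    buf,
    shape=(num_blocks, real_page),
    strides=(padded_page * buf.itemsize, buf.itemsize),
)
# view[b, i] aliases buf[b * padded_page + i]; elements 6 and 7 of each
# padded page are never visible through the view.
```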