Add support for Hybrid Models in extract hidden states #159

Draft

rahul-tuli wants to merge 6 commits into main from
Conversation

rahul-tuli (Member) commented on Apr 9, 2026
- Add CacheOnlySpec(MLAAttentionSpec) to kv_cache_interface.py so it duck-types through all existing AttentionSpec code paths
- Pre-filter CacheOnlySpec in get_kv_cache_groups() before type-unification routing to prevent crashes with mixed spec types
- Joint budget calculation in get_kv_cache_config_from_groups() via extra_bytes_per_block parameter on get_num_blocks()
- Gate HMA disable in config with supports_hma() check so hybrid models keep their per-group block allocators
- Add SupportsHMA to ExampleHiddenStatesConnector with correct cache_group_idx for block_ids
- Resolve CacheOnly slot_mapping from per-layer mappings in the proposer instead of using main group's common_attn_metadata
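The duck-typing idea in the first bullet can be sketched as follows. This is an illustrative simplification, not vLLM's actual class definitions: the field names, the page-size formulas, and the `total_page_bytes` helper are all assumptions made for the example. The point is only that a subclass of an existing spec type flows through generic `AttentionSpec` code paths without new `isinstance` branches.

```python
# Hypothetical sketch: CacheOnlySpec subclasses an attention-spec type so
# generic AttentionSpec code accepts it unchanged. Names and formulas are
# illustrative, not vLLM's real implementation.
from dataclasses import dataclass


@dataclass(frozen=True)
class AttentionSpec:
    block_size: int      # tokens per KV-cache block
    head_size: int
    num_kv_heads: int
    dtype_bytes: int

    @property
    def page_size_bytes(self) -> int:
        # 2x for the separate K and V planes
        return 2 * self.block_size * self.num_kv_heads * self.head_size * self.dtype_bytes


@dataclass(frozen=True)
class MLAAttentionSpec(AttentionSpec):
    # MLA-style caches store a single latent plane, so no K/V doubling
    @property
    def page_size_bytes(self) -> int:
        return self.block_size * self.num_kv_heads * self.head_size * self.dtype_bytes


@dataclass(frozen=True)
class CacheOnlySpec(MLAAttentionSpec):
    """A cache-only 'layer' that stores hidden states but runs no attention."""


def total_page_bytes(specs: list[AttentionSpec]) -> int:
    # Generic code written against AttentionSpec handles CacheOnlySpec too,
    # because it duck-types through the same interface.
    return sum(s.page_size_bytes for s in specs)
```

A hidden-state spec reuses the single-plane page-size math, which is why subclassing an MLA-style spec (rather than the K/V-doubled base) is attractive here.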
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
…V cache groups

CacheOnly layers (used for hidden-state extraction) previously created a separate KV cache group, which caused issues with hybrid model support: scattered isinstance checks, memory accounting hacks (extra_bytes_per_block), and potential 3-way hybrid coordinator failures.

This commit makes CacheOnly layers "supplementary" -- they piggyback on group 0's block table and slot mappings instead of being full group participants. They still get their own allocated tensors but are invisible to the KV cache coordinator. This eliminates the need for separate block management while keeping the memory accounting clean.

Key changes:
- Add supplementary_specs field to KVCacheConfig
- Strip CacheOnly from groups in get_kv_cache_groups() via split_supplementary_specs()
- Account for supplementary memory in get_kv_cache_config_from_groups() budget
- Handle supplementary layers in attn_utils (init_attn_backend, _allocate_kv_cache, _reshape_kv_cache)
- Revert proposer to use group 0's slot_mapping directly
- Remove _cache_group_idx from ExampleHiddenStatesConnector
- Remove CacheOnlySpec from single_type_kv_cache_manager spec_manager_map

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
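The split_supplementary_specs() step described above can be sketched as a simple partition of the per-layer spec map before groups are formed, so the coordinator never sees supplementary layers. The function signature and the `{layer_name: spec}` shape are assumptions for illustration, not the actual vLLM API.

```python
# Hedged sketch of split_supplementary_specs(): strip supplementary
# (CacheOnly-style) specs out of the per-layer map before KV cache group
# formation. Names and shapes are illustrative.
def split_supplementary_specs(specs_by_layer, is_supplementary):
    """Partition a {layer_name: spec} map into (regular, supplementary)."""
    regular, supplementary = {}, {}
    for name, spec in specs_by_layer.items():
        target = supplementary if is_supplementary(spec) else regular
        target[name] = spec
    return regular, supplementary
```

Only the `regular` map would then feed group formation; the `supplementary` map is carried separately for tensor allocation and memory accounting.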
…eshape

The model runner has its own _allocate_kv_cache_tensors and _reshape_kv_cache_tensors methods (separate from attn_utils.py). These were missing supplementary layer support:

1. _allocate_kv_cache_tensors: the assertion didn't include supplementary layer names, causing "Some layers are not correctly initialized"
2. _reshape_kv_cache_tensors: supplementary layers were never reshaped or included in kv_caches, so they wouldn't be bound

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
CacheOnlyAttentionLayer captured cache_config.block_size at __init__ time (16), but for hybrid models with Mamba layers the block_size is later adjusted upward (e.g. to 528) to match the Mamba page size. Since CacheOnly layers share slot_mapping with group 0, slot indices are computed as block_id * 528 + offset, but the CacheOnly KV cache was shaped with block_size=16, causing a CUDA "index out of bounds" error.

Fix: read vllm_config.cache_config.block_size at get_kv_cache_spec() time instead of using the stale self.block_size from init.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
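The stale-snapshot bug and its fix reduce to a small pattern: a config object is mutated after model loading, so the layer must re-read it lazily rather than copy it in __init__. All class and method names below are illustrative stand-ins, assuming only that block_size can change between construction and spec time.

```python
# Minimal sketch of the stale block_size bug: snapshotting a mutable
# config field in __init__ versus re-reading it at spec time. Names are
# illustrative, not vLLM's actual classes.
class CacheConfig:
    def __init__(self, block_size: int):
        self.block_size = block_size


class CacheOnlyAttentionLayer:
    def __init__(self, cache_config: CacheConfig):
        self.cache_config = cache_config
        self.block_size = cache_config.block_size  # stale snapshot (the bug)

    def spec_block_size_buggy(self) -> int:
        # Can disagree with group 0's block size after post-load adjustment
        return self.block_size

    def spec_block_size_fixed(self) -> int:
        # Re-read at get_kv_cache_spec() time, after any adjustment
        return self.cache_config.block_size


cfg = CacheConfig(block_size=16)
layer = CacheOnlyAttentionLayer(cfg)
cfg.block_size = 528  # hybrid model bumps block_size after model loading
```

With the buggy read, slot indices computed as block_id * 528 + offset land outside a cache shaped for blocks of 16, which matches the out-of-bounds symptom described above.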
Use the GPU slot_mapping from the attention metadata when extracting hidden states from the KV cache, rather than recomputing it from scheduler block IDs. On hybrid models (e.g. Qwen3.5), kernel block splitting causes the two to diverge, producing all-zeros extraction.

Also map supplementary CacheOnly layers to group 0's slot_mapping in both build_slot_mappings_by_layer and _get_slot_mappings, since they share group 0's block table.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
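The second change above, aliasing supplementary layers onto group 0's slot_mapping, can be sketched as below. The function name echoes the commit message, but the argument shapes (`{group_idx: slot_mapping}`, `{layer_name: group_idx}`) are assumptions for the example.

```python
# Illustrative sketch: supplementary CacheOnly layers share group 0's
# block table, so their per-layer slot_mapping entry is an alias of the
# group-0 mapping rather than a freshly recomputed one.
def build_slot_mappings_by_layer(group_slot_mappings, layer_to_group,
                                 supplementary_layers):
    """group_slot_mappings: {group_idx: slot_mapping}; returns per-layer map."""
    out = {name: group_slot_mappings[g] for name, g in layer_to_group.items()}
    for name in supplementary_layers:
        out[name] = group_slot_mappings[0]  # piggyback on group 0
    return out
```

Aliasing (rather than recomputing from scheduler block IDs) is what keeps the supplementary layers consistent with the GPU-side metadata even when kernel block splitting reshuffles the mapping.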
Replace the inlined 37-line reshape block in gpu_model_runner's _reshape_kv_cache_tensors with a call to the shared _reshape_one_layer helper already factored out in attn_utils.

Include supplementary specs (e.g. CacheOnly) in the memory budget used by _max_memory_usage_bytes_from_groups, _estimate_max_model_len, and _auto_fit_max_model_len so that the "enough memory" check and auto-fit context length estimation account for supplementary tensors.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
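The budget adjustment amounts to folding supplementary page sizes into the per-block cost before dividing out the number of blocks. This is a deliberately rough sketch under that assumption; the real accounting in kv_cache_utils.py has more inputs than shown here.

```python
# Rough sketch of a joint memory budget: supplementary (CacheOnly) page
# sizes add to the per-block cost, so fewer blocks fit in the same
# available memory. Purely illustrative, not vLLM's actual accounting.
def num_blocks(available_bytes: int, group_page_sizes: list[int],
               supplementary_page_sizes: list[int]) -> int:
    per_block = sum(group_page_sizes) + sum(supplementary_page_sizes)
    return available_bytes // per_block
```

Omitting the supplementary term is exactly the failure mode the commit guards against: the "enough memory" check would pass while the real allocation overshoots the budget.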
mgoin added a commit that referenced this pull request on Apr 15, 2026
Hidden-state extraction currently breaks on hybrid-attention models
(Mamba + full attention, e.g. Qwen3.5) because `kv_transfer_config` is
set → `disable_hybrid_kv_cache_manager=True` →
`unify_hybrid_kv_cache_specs` fails to fold `MambaSpec` into a uniform
type. This change unblocks the path with a set of minimal, scoped fixes:
- Add a proper `HiddenStateCacheSpec` as a sibling of `FullAttentionSpec`
(not a subclass) so the cache-only layer no longer abuses
`MLAAttentionSpec` for its page-size math. Sibling placement avoids
tripping isinstance checks written for real attention layers and
avoids inheriting MLA-specific fields like `cache_dtype_str`.
- Teach `UniformTypeKVCacheSpecs.is_uniform_type` to accept mixed
`{FullAttentionSpec, HiddenStateCacheSpec}` groups, keeping Llama +
extract_hidden_states on the existing single-group allocator path.
- Register `HiddenStateCacheSpec` in `spec_manager_map` (reuses
`FullAttentionManager` — same "one slot per token" semantics).
- Gate the `kv_transfer_config → disable HMA` logic in `VllmConfig` on
`supports_hma(connector_cls)`, so `SupportsHMA` connectors keep HMA on.
- Mark `ExampleHiddenStatesConnector` as `SupportsHMA` and implement
`request_finished_all_groups`.
- Read `slot_mapping` from `attn_metadata.slot_mapping` in
`save_kv_layer` instead of recomputing from scheduler block IDs: the
scheduler's block size may diverge from the cache-only layer's under
HMA. This also deletes the now-dead `ReqMeta.slot_mapping` field and
its per-request CPU tensor allocation.
- Fix a grouping heuristic in `_get_kv_cache_groups_uniform_page_size`:
singleton `HiddenStateCacheSpec` buckets now skip the min/max
computation so they don't collapse `group_size` to 1 and fan out the
main attention into one-layer-per-group.
- Re-read `cache_config.block_size` in `CacheOnlyAttentionLayer.get_kv_cache_spec`
rather than the stale value cached at `__init__` time (hybrid models
bump it after model loading).
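The relaxed uniformity check described in the list above can be sketched as follows. The class names follow the commit message, but the bodies are empty stand-ins and the acceptance logic is a simplification of whatever `UniformTypeKVCacheSpecs.is_uniform_type` actually does: a group mixing FullAttentionSpec and HiddenStateCacheSpec still counts as uniform, while anything else must be a single concrete type.

```python
# Sketch of the relaxed is_uniform_type check: FullAttentionSpec and
# HiddenStateCacheSpec may mix freely (same "one slot per token"
# semantics); any other mix is non-uniform. Simplified, illustrative.
class FullAttentionSpec: ...
class HiddenStateCacheSpec: ...  # sibling of FullAttentionSpec, not a subclass
class MambaSpec: ...


def is_uniform_type(specs) -> bool:
    allowed = (FullAttentionSpec, HiddenStateCacheSpec)
    if all(isinstance(s, allowed) for s in specs):
        return True
    # Otherwise require exactly one concrete spec type
    return len({type(s) for s in specs}) == 1
```

This is what keeps Llama + extract_hidden_states on the existing single-group allocator path while letting Mamba groups route separately.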
Llama + extract_hidden_states stays on its current `UniformTypeKVCacheSpecs`
single-group path. Qwen3.5 + extract_hidden_states now reaches the HMA
multi-group path; any remaining page-size-unification divisibility
issues are model-config dependent and can be addressed as follow-ups.
This is an alternative design to #159
that avoids introducing `supplementary_specs` plumbing throughout
`kv_cache_utils.py`. The goal is roughly equivalent runtime behavior for
supported configurations with substantially less code surface.
Test plan:
pytest tests/v1/core/test_kv_cache_utils.py
-> 59 passed (50 pre-existing + 9 new)
pre-commit run ruff-check / ruff-format / mypy-3.10
-> all passing
Note: this change was developed with AI assistance (Claude). Every line
has been reviewed and the test plan executed locally.
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>