
feat: Support Hybrid Attention Models with extract_hidden_states#157

Draft
rahul-tuli wants to merge 4 commits into main from cache-only-spec-hidden-states

Conversation

@rahul-tuli
Member

This PR adds support for extracting hidden states with Hybrid Attention Models (for example, Qwen/Qwen3.5-9B).

The core issue was that CacheOnlyAttentionLayer — the single-layer "draft model"
used by extract_hidden_states — was previously registered with MLAAttentionSpec.
This caused two compounding bugs:

1. HMA was unconditionally disabled whenever any kv_transfer_config was set. With
   MLAAttentionSpec, the CacheOnly layer appeared as a regular attention layer,
   causing unify_hybrid_kv_cache_specs to fail on hybrid models (Qwen3.5,
   Gemma4), whose attention types cannot be unified.
2. Even when HMA was turned on, unification with MLAAttentionSpec was not
   guaranteed to succeed (it failed for Gemma).

A dedicated CacheOnlySpec(KVCacheSpec) is introduced with semantics that match
hidden-state storage:

page_size_bytes = block_size × num_hidden_states × hidden_size × dtype_size_bytes

Because CacheOnlySpec does not inherit from AttentionSpec, all existing
isinstance(spec, AttentionSpec) guards naturally exclude it, and it is routed
through a dedicated code path in each subsystem.
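A minimal sketch of the page-size arithmetic above. The class and field names here (CacheOnlySpecSketch, dtype_size_bytes) are illustrative stand-ins based on the formula, not vLLM's actual CacheOnlySpec API:

```python
from dataclasses import dataclass


@dataclass
class CacheOnlySpecSketch:
    """Stand-in for a hidden-state-only KV cache spec (names hypothetical)."""

    block_size: int          # tokens per paged block
    num_hidden_states: int   # number of layers whose hidden states are stored
    hidden_size: int         # model hidden dimension
    dtype_size_bytes: int    # e.g. 2 for bfloat16

    @property
    def page_size_bytes(self) -> int:
        # No factor-of-2 key/value split: one hidden-state vector is stored
        # per (token, captured-layer) slot.
        return (self.block_size * self.num_hidden_states
                * self.hidden_size * self.dtype_size_bytes)


spec = CacheOnlySpecSketch(block_size=16, num_hidden_states=4,
                           hidden_size=4096, dtype_size_bytes=2)
print(spec.page_size_bytes)  # 16 * 4 * 4096 * 2 = 524288
```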

Memory budget

CacheOnlySpec groups are pre-extracted from the main spec map before routing.
get_kv_cache_config_from_groups deducts their memory cost from the shared pool:

num_blocks = floor(available_memory / (main_bytes_per_block + cache_only_bytes_per_block))

Both the main KV cache tensors and the CacheOnly tensor are allocated for the same
num_blocks, ensuring that any block ID valid for the main cache is also valid for
the CacheOnly cache. The CacheOnly tensor is given its own independent KVCacheTensor
entry (not shared with attention layers).
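The shared-pool sizing above can be sketched as a one-line calculation; the function and argument names are illustrative, not vLLM's actual get_kv_cache_config_from_groups signature:

```python
def shared_num_blocks(available_memory: int,
                      main_bytes_per_block: int,
                      cache_only_bytes_per_block: int) -> int:
    """Blocks affordable when both caches are sized identically.

    Because the main KV cache and the CacheOnly tensor share the same
    num_blocks, each block effectively costs the sum of the two per-block
    footprints.
    """
    return available_memory // (main_bytes_per_block + cache_only_bytes_per_block)


# e.g. a 1 GiB pool, 1 MiB/block main cache, 512 KiB/block CacheOnly tensor
print(shared_num_blocks(1 << 30, 1 << 20, 1 << 19))  # 682
```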

HMA interaction

ExampleHiddenStatesConnector now declares SupportsHMA. The VllmConfig post-init
check uses this to skip the blanket HMA-disable that would otherwise fire whenever a
kv_transfer_config is present. Keeping HMA enabled is required for hybrid models
(sliding-window attention, Mamba) to use their separate KV cache groups correctly.
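The shape of that gate can be illustrated as follows; the mixin-based check is an assumption about how the supports_hma() test is wired, not vLLM's exact VllmConfig code:

```python
class SupportsHMA:
    """Marker mixin: the connector is safe to run with HMA enabled."""


class ExampleHiddenStatesConnectorSketch(SupportsHMA):
    """Stand-in for the connector declaring HMA support."""


def should_disable_hma(connector: object) -> bool:
    # The blanket HMA-disable (triggered by any kv_transfer_config) fires
    # only for connectors that do NOT declare SupportsHMA.
    return not isinstance(connector, SupportsHMA)


print(should_disable_hma(ExampleHiddenStatesConnectorSketch()))  # False
print(should_disable_hma(object()))                              # True
```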


High-level description of each change

| File | Change |
| --- | --- |
| vllm/v1/kv_cache_interface.py | Add CacheOnlySpec; guard UniformTypeKVCacheSpecs.is_uniform_type against it |
| vllm/v1/core/single_type_kv_cache_manager.py | Add CacheOnlyManager (prefix cache disabled); register it in spec_manager_map |
| vllm/v1/core/kv_cache_utils.py | Isolate CacheOnly groups before routing in get_kv_cache_groups; account for CacheOnly memory in get_kv_cache_config_from_groups and _max_memory_usage_bytes_from_groups |
| vllm/model_executor/models/extract_hidden_states.py | Return CacheOnlySpec from get_kv_cache_spec() instead of MLAAttentionSpec |
| vllm/v1/worker/gpu_model_runner.py | Add a CacheOnly branch in _get_slot_mappings (real block table, 0-padded); add CacheOnly to the block table in _reinitialize_input_batch via iter(kernel_block_sizes) decoupling |
| vllm/v1/spec_decode/extract_hidden_states.py | Extract the per-layer cache_only_slot_mapping from the full slot_mappings dict in propose() |
| vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py | Declare SupportsHMA; detect the CacheOnly group index dynamically via _cache_group_idx; add request_finished_all_groups; fix block_ids indexing to use _cache_group_idx |
| vllm/config/vllm.py | Gate the HMA-disable behind a supports_hma() check so connectors implementing SupportsHMA keep HMA enabled |
| tests/v1/core/test_kv_cache_utils.py | 5 new unit tests covering CacheOnlySpec page size, group routing, and memory budget |

Test plan

Unit tests

.venv/bin/python -m pytest tests/v1/core/test_kv_cache_utils.py \
  -k "cache_only" -v

Manual end-to-end verification

Manually verified with examples/offline_inference/extract_hidden_states.py on both
a non-hybrid and a hybrid model:

Hybrid (non-uniform attention):

model="Qwen/Qwen3.5-9B", tensor_parallel_size=4
eagle_aux_hidden_state_layer_ids=[1, 2, 3, 4]

Non-hybrid (uniform full attention):

model="Qwen/Qwen3-8B", tensor_parallel_size=4
eagle_aux_hidden_state_layer_ids=[1, 2, 3, 4]

Both runs produced correctly shaped hidden_states tensors saved to safetensors, with
token_ids matching the prompt token IDs.


Checklist

  • All new code paths have corresponding unit tests
  • Existing test suite unaffected (new code paths gated behind isinstance(spec, CacheOnlySpec))
  • Pre-commit hooks pass on all touched files
  • Manual end-to-end verification on non-hybrid and hybrid models
  • AI assistance used (Claude); all changed lines reviewed by human submitter

Add CacheOnlySpec(KVCacheSpec) as a dedicated spec for
CacheOnlyAttentionLayer used by the extract_hidden_states speculative
decoding method. Previously MLAAttentionSpec was used as a proxy, which
caused two bugs:
- As an AttentionSpec subclass it entered unify_hybrid_kv_cache_specs,
  causing failures on hybrid models (Qwen3-MoE, Gemma4) where attention
  types cannot be unified.
- The CacheOnly group was incorrectly excluded from block-table and
  slot-mapping construction (handled in a follow-up commit).

CacheOnlySpec carries no factor-of-2 key/value split:
  page_size_bytes = block_size * num_hidden_states * hidden_size * dtype_bytes

Changes:
- vllm/v1/kv_cache_interface.py: add CacheOnlySpec; guard
  UniformTypeKVCacheSpecs.is_uniform_type against it
- vllm/v1/core/kv_cache_utils.py: pre-extract CacheOnly specs before
  routing in get_kv_cache_groups; deduct CacheOnly memory from shared
  budget in get_kv_cache_config_from_groups and
  _max_memory_usage_bytes_from_groups
- vllm/v1/core/single_type_kv_cache_manager.py: add CacheOnlyManager
  (prefix cache disabled); register in spec_manager_map
- vllm/model_executor/models/extract_hidden_states.py: return
  CacheOnlySpec from get_kv_cache_spec() instead of MLAAttentionSpec
- tests/v1/core/test_kv_cache_utils.py: 5 new tests covering page-size
  formula, group routing, hybrid routing, and memory budget

Co-authored-by: Claude

Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
…ing, and reshape

CacheOnlySpec groups need a real block table entry and correct slot
mappings so hidden states are written to the right paged-block slots.
Previously these groups were skipped alongside EncoderOnlyAttentionSpec,
leaving input_batch.block_table without a CacheOnly entry (KeyError).

- _get_slot_mappings: split EncoderOnly/CacheOnly branches; for CacheOnly
  use the real block table and fill padding with 0 (not -1) — basic_cache
  uses plain tensor indexing and cannot handle negative slot IDs
- _reinitialize_input_batch: include CacheOnly in block_sizes so
  input_batch.block_table gains an entry; decouple kernel_block_sizes
  (which skips CacheOnly) using iter()/next()
- _reshape_kv_cache_tensors: reshape CacheOnly tensors to
  [num_blocks, block_size, num_hidden_states, hidden_size]
- prepare_kernel_block_sizes (worker/utils.py): skip CacheOnlySpec —
  it has no attention kernel and needs no kernel block size entry
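The iter()/next() decoupling mentioned above can be sketched like this. The spec list, block sizes, and CacheOnly handling are hypothetical illustrations of the pattern, not the actual _reinitialize_input_batch code:

```python
class CacheOnlySpec:
    """Stand-in for the real spec class."""


# Per-group specs include a CacheOnly entry, but kernel_block_sizes
# intentionally skips it (CacheOnly has no attention kernel), so the two
# sequences must be walked with independent cursors.
specs = ["full_attn", CacheOnlySpec(), "sliding_window"]
kernel_block_sizes = iter([16, 32])  # no entry for the CacheOnly group

block_sizes = []
for spec in specs:
    if isinstance(spec, CacheOnlySpec):
        # Give CacheOnly a block-table entry WITHOUT consuming a kernel
        # block size; without the iter()/next() split it would wrongly
        # steal the sliding-window group's entry.
        block_sizes.append(16)
    else:
        block_sizes.append(next(kernel_block_sizes))

print(block_sizes)  # [16, 16, 32]
```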

Co-authored-by: Claude

Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
propose() was passing common_attn_metadata.slot_mapping (the main
attention group's slot mapping, which uses -1 padding and different
block IDs) to _get_slot_mapping. This caused hidden states to be written
to incorrect or out-of-bounds slots in the CacheOnly KV cache.

Extract the slot mapping for the CacheOnly layer from the per-layer
slot_mappings dict (slot_mappings[attn_layer_names[0]]) so the proposer
writes hidden states to the same block slots that the KV connector reads
from in extract_from_kv_cache.

Co-authored-by: Claude

Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
… CacheOnly

Without SupportsHMA, any kv_transfer_config triggers a blanket HMA
disable in VllmConfig._post_init. With CacheOnlySpec now as a separate
group, disabling HMA would cause unify_hybrid_kv_cache_specs to encounter
CacheOnlySpec alongside attention specs and fail. SupportsHMA opts the
connector out of that disable path.

Additional connector fixes required once CacheOnly is its own group:
- _cache_group_idx: detect the CacheOnly group index at __init__ so all
  block_ids accesses use the correct group (was hardcoded to [0])
- request_finished_all_groups: delegate to request_finished using
  block_ids[_cache_group_idx] (required by SupportsHMA contract)
- build_connector_meta / request_finished: use _cache_group_idx instead
  of literal 0 when reading block_ids

Co-authored-by: Claude

Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
@rahul-tuli
Member Author

Tracked in JIRA: https://issues.redhat.com/browse/INFERENG-5961

JIRA Details:

  • Issue Type: Task
  • Priority: Critical
  • Component: Speculators
  • Status: Backlog

for i, g in enumerate(groups)
if isinstance(g.kv_cache_spec, CacheOnlySpec)
),
0,

We have 0 as a fallback here, but if a CacheOnlySpec isn't found, doesn't that mean the model isn't set up correctly? Perhaps we should fail here instead.

Member Author


Good catch! You're correct: 0 is wrong here, and raising an error is the right thing to do.
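Following that suggestion, the lookup could fail hard when no CacheOnly group exists instead of falling back to 0. A sketch, with stand-in Group/CacheOnlySpec classes and a hypothetical function name:

```python
class CacheOnlySpec:
    """Stand-in for the real spec class."""


class Group:
    """Stand-in for a KV cache group."""

    def __init__(self, kv_cache_spec: object) -> None:
        self.kv_cache_spec = kv_cache_spec


def find_cache_only_group_idx(groups: list[Group]) -> int:
    try:
        return next(
            i for i, g in enumerate(groups)
            if isinstance(g.kv_cache_spec, CacheOnlySpec)
        )
    except StopIteration:
        # A missing CacheOnly group means the model was not configured for
        # extract_hidden_states; surface that instead of silently using 0.
        raise ValueError(
            "No CacheOnlySpec group found; model is not configured for "
            "extract_hidden_states"
        ) from None


groups = [Group(object()), Group(CacheOnlySpec())]
print(find_cache_only_group_idx(groups))  # 1
```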

for spec in kv_cache_specs.values()
)
elif isinstance(one_spec, CacheOnlySpec):
return False

Is this always called with all layers (in which case it would be safe to assume there are at least two specs)?

If not, (if this is ever called with a subset of the layers, like just the draft layers), then we shouldn't assume this and should instead check the layers.
