Gemma4 extract hidden states support #158

Draft
rahul-tuli wants to merge 3 commits into cache-only-spec-hidden-states from cache-only-spec-gemma4-support

Conversation


@rahul-tuli rahul-tuli commented Apr 8, 2026

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile

from safetensors import safe_open

from vllm import LLM, SamplingParams


def run_vllm():
    with tempfile.TemporaryDirectory() as tmpdirname:
        llm = LLM(
            model="google/gemma-4-31B-it",
            tensor_parallel_size=4,
            speculative_config={
                "method": "extract_hidden_states",
                "num_speculative_tokens": 1,
                "draft_model_config": {
                    "hf_config": {
                        "eagle_aux_hidden_state_layer_ids": [  
                            2
                        ],
                    }
                },
            },
            kv_transfer_config={
                "kv_connector": "ExampleHiddenStatesConnector",
                "kv_role": "kv_producer",
                "kv_connector_extra_config": {
                    "shared_storage_path": tmpdirname,
                },
            },
            disable_hybrid_kv_cache_manager=False,
        )

        prompts = ["Generate a sentence with hidden states", "Write a python function"]
        sampling_params = SamplingParams(max_tokens=1)
        outputs = llm.generate(prompts, sampling_params)

        for output in outputs:
            print("\nPrompt:", output.prompt)
            print("Prompt token ids:", output.prompt_token_ids)

            hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
            assert hidden_states_path is not None
            print("Prompt hidden states path:", hidden_states_path)

            with safe_open(hidden_states_path, "pt") as f:
                token_ids = f.get_tensor("token_ids")
                hidden_states = f.get_tensor("hidden_states")

                print("Extracted token ids:", token_ids)  # Matches prompt token ids
                print(
                    "Extracted hidden states shape:", hidden_states.shape
                )  # [prompt len, num extracted aux layers, hidden size]
                print("Extracted hidden states:", hidden_states)

if __name__ == "__main__":
    run_vllm()

Output:

Prompt: Generate a sentence with hidden states
Prompt token ids: [43204, 496, 13315, 607, 11497, 5022]
Prompt hidden states path: /tmp/tmpgtw8bf1u/0-87d1df37.safetensors
Extracted token ids: tensor([43204,   496, 13315,   607, 11497,  5022])
Extracted hidden states shape: torch.Size([6, 1, 5376])
Extracted hidden states: tensor([[[ 0.0063,  0.0258, -0.0055,  ...,  0.0058,  0.0134, -0.0349]],

        [[-0.0293, -0.0229,  0.0025,  ..., -0.0054,  0.0320,  0.0386]],

        [[-0.0044,  0.0017, -0.0014,  ...,  0.0013, -0.0069, -0.0364]],

        [[ 0.0036,  0.0081, -0.0023,  ...,  0.0014,  0.0033,  0.0101]],

        [[-0.0024,  0.0109,  0.0077,  ...,  0.0026,  0.0049,  0.0079]],

        [[ 0.0055,  0.0055, -0.0080,  ...,  0.0040,  0.0160, -0.0110]]],
       dtype=torch.bfloat16)

Prompt: Write a python function
Prompt token ids: [6974, 496, 23181, 1292]
Prompt hidden states path: /tmp/tmpgtw8bf1u/1-a76e7e39.safetensors
Extracted token ids: tensor([ 6974,   496, 23181,  1292])
Extracted hidden states shape: torch.Size([4, 1, 5376])
Extracted hidden states: tensor([[[-0.0028,  0.0074,  0.0009,  ..., -0.0049, -0.0020, -0.0183]],

        [[-0.0269, -0.0239,  0.0028,  ..., -0.0037,  0.0291,  0.0369]],

        [[-0.0063,  0.0049, -0.0017,  ..., -0.0007,  0.0024,  0.0199]],

        [[-0.0072,  0.0078,  0.0041,  ...,  0.0021,  0.0100,  0.0160]]],
       dtype=torch.bfloat16)

Gemma4ForCausalLM previously lacked SupportsEagle3, causing a
RuntimeError when extract_hidden_states speculative decoding attempted
to call set_aux_hidden_state_layers on the model.

Changes:
- Gemma4Model inherits EagleModelMixin to gain aux_hidden_state_layers
  and _maybe_add_hidden_state
- Gemma4Model.forward() (normal path) collects aux hidden states per
  layer via _maybe_add_hidden_state and returns (hidden_states,
  aux_hidden_states) when any are collected; fast_prefill path is
  unchanged (kv_sharing_fast_prefill is not used with extract_hidden_states)
- Gemma4ForCausalLM inherits SupportsEagle3, satisfying the
  supports_eagle3() check in gpu_model_runner before set_aux_hidden_state_layers
  is called
- Inline the actual_layer_idx computation unconditionally in the layer
  loop (it was previously gated on per_layer_inputs is not None)
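The collection pattern described in the changes above can be sketched as follows. The names aux_hidden_state_layers and _maybe_add_hidden_state come from the PR text; everything else here is a simplified stand-in (hidden states are plain numbers, the layer body is a toy increment), not vLLM code.

```python
class EagleModelMixinSketch:
    """Records the hidden state entering each configured layer index."""

    aux_hidden_state_layers: tuple = ()

    def _maybe_add_hidden_state(self, aux_hidden_states, layer_idx, hidden_states):
        # Collect this layer's input hidden state only if the layer was requested.
        if layer_idx in self.aux_hidden_state_layers:
            aux_hidden_states.append(hidden_states)


class TinyGemmaModel(EagleModelMixinSketch):
    def __init__(self, num_layers, aux_layers):
        self.num_layers = num_layers
        self.aux_hidden_state_layers = tuple(aux_layers)

    def forward(self, hidden_states):
        aux_hidden_states = []
        for layer_idx in range(self.num_layers):
            self._maybe_add_hidden_state(aux_hidden_states, layer_idx, hidden_states)
            hidden_states = hidden_states + 1  # stand-in for a decoder layer
        # Return a tuple only when aux states were actually collected,
        # mirroring the (hidden_states, aux_hidden_states) shape in the PR.
        if aux_hidden_states:
            return hidden_states, aux_hidden_states
        return hidden_states


model = TinyGemmaModel(num_layers=4, aux_layers=[2])
out = model.forward(0)
print(out)  # (4, [2]): final state plus the state entering layer 2
```

With no aux layers configured, forward returns a bare value, which keeps the normal (non-speculative) path unchanged.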

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Claude

The speculative config validator has an allowlist for models that
support aux hidden state outputs (used by eagle3, extract_hidden_states,
and dflash). Gemma4 now implements SupportsEagle3 via EagleModelMixin,
so add 'gemma4' to the list to unblock the validation.

The check uses substring matching against model_type ('gemma4_text'
contains 'gemma4'), which is consistent with how other model types
are matched (e.g. 'qwen' matches 'qwen3_moe', 'qwen2', etc.).
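The substring check described above can be illustrated with a minimal sketch. The list contents and the helper name are illustrative only, not the actual vLLM validator code.

```python
# Hypothetical allowlist; only the matching logic is the point here.
AUX_HIDDEN_STATE_MODEL_TYPES = ["llama", "qwen", "gemma4"]


def model_type_supports_aux_hidden_states(model_type: str) -> bool:
    # 'gemma4_text' matches because it contains the substring 'gemma4',
    # just as 'qwen3_moe' and 'qwen2' both match 'qwen'.
    return any(entry in model_type for entry in AUX_HIDDEN_STATE_MODEL_TYPES)


print(model_type_supports_aux_hidden_states("gemma4_text"))  # True
print(model_type_supports_aux_hidden_states("qwen3_moe"))    # True
print(model_type_supports_aux_hidden_states("mistral"))      # False
```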

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Claude

The resolved architecture for Gemma4 is Gemma4ForConditionalGeneration
(the multimodal wrapper in gemma4_mm.py). The supports_eagle3() check
in gpu_model_runner runs on this outer wrapper, so adding SupportsEagle3
only to Gemma4ForCausalLM was insufficient.

SupportsEagle3.set_aux_hidden_state_layers delegates via get_language_model()
→ Gemma4ForCausalLM.model (Gemma4Model, EagleModelMixin), so no extra
implementation is needed in the wrapper.
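The delegation chain described above can be sketched like this. Class names mirror the PR text, but the bodies are simplified stand-ins for illustration, not the vLLM implementations.

```python
class Gemma4ModelSketch:
    """Inner model holding the aux layer configuration (per EagleModelMixin)."""

    aux_hidden_state_layers: tuple = ()


class Gemma4ForCausalLMSketch:
    def __init__(self):
        self.model = Gemma4ModelSketch()

    def set_aux_hidden_state_layers(self, layers):
        self.model.aux_hidden_state_layers = tuple(layers)


class Gemma4ForConditionalGenerationSketch:
    """Multimodal wrapper; the supports_eagle3() check runs on this class."""

    def __init__(self):
        self.language_model = Gemma4ForCausalLMSketch()

    def get_language_model(self):
        return self.language_model

    def set_aux_hidden_state_layers(self, layers):
        # Delegate to the inner causal LM; the wrapper holds no extra state.
        self.get_language_model().set_aux_hidden_state_layers(layers)


wrapper = Gemma4ForConditionalGenerationSketch()
wrapper.set_aux_hidden_state_layers([2])
print(wrapper.get_language_model().model.aux_hidden_state_layers)  # (2,)
```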

Co-Authored-By: Claude <noreply@anthropic.com>

Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
@rahul-tuli rahul-tuli changed the title Cache only spec gemma4 support Gemma4 extract hidden states support Apr 8, 2026
@MeganEFlynn MeganEFlynn mentioned this pull request Apr 9, 2026
2 tasks