Gemma4 extract hidden states support #158
Draft
rahul-tuli wants to merge 3 commits into cache-only-spec-hidden-states from
Conversation
Gemma4ForCausalLM previously lacked SupportsEagle3, causing a RuntimeError when extract_hidden_states speculative decoding attempted to call set_aux_hidden_state_layers on the model.

Changes:
- Gemma4Model inherits EagleModelMixin to gain aux_hidden_state_layers and _maybe_add_hidden_state
- Gemma4Model.forward() (normal path) collects aux hidden states per layer via _maybe_add_hidden_state and returns (hidden_states, aux_hidden_states) when any are collected; the fast_prefill path is unchanged (kv_sharing_fast_prefill is not used with extract_hidden_states)
- Gemma4ForCausalLM inherits SupportsEagle3, satisfying the supports_eagle3() check in gpu_model_runner before set_aux_hidden_state_layers is called
- The actual_layer_idx computation is now inlined unconditionally in the layer loop (it was previously gated on per_layer_inputs is not None)

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Claude
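The per-layer collection described above can be sketched in miniature. This is a hypothetical illustration of the mixin pattern, not vLLM's actual implementation: the class and method names (EagleModelMixin, _maybe_add_hidden_state, aux_hidden_state_layers) follow the PR description, but the layer bodies and signatures are toy stand-ins.

```python
# Hedged sketch of the EagleModelMixin pattern described in this PR.
# Real vLLM layers take tensors; here a decoder layer is a stand-in that
# adds 1.0, so the collection logic is easy to follow.
from typing import List, Tuple


class EagleModelMixin:
    """Collects hidden states from a configured subset of layers."""

    aux_hidden_state_layers: Tuple[int, ...] = ()

    def _maybe_add_hidden_state(self, aux, layer_idx, hidden_states):
        # Record this layer's output only if that layer index was requested.
        if layer_idx in self.aux_hidden_state_layers:
            aux.append(hidden_states)


class ToyModel(EagleModelMixin):
    def __init__(self, num_layers: int):
        self.num_layers = num_layers

    def forward(self, hidden_states: float):
        aux: List[float] = []
        for layer_idx in range(self.num_layers):
            hidden_states = hidden_states + 1.0  # stand-in for a decoder layer
            self._maybe_add_hidden_state(aux, layer_idx, hidden_states)
        # Match the contract described above: return a tuple only when aux
        # hidden states were actually collected, else just the final states.
        if aux:
            return hidden_states, aux
        return hidden_states


model = ToyModel(num_layers=4)
model.aux_hidden_state_layers = (1, 3)
out, aux = model.forward(0.0)
```

With layers 1 and 3 requested, the toy forward pass returns the final value plus the two recorded intermediate values; with an empty aux_hidden_state_layers it returns only the final hidden states, which is the "unchanged fast_prefill path" shape.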
The speculative config validator has an allowlist for models that
support aux hidden state outputs (used by eagle3, extract_hidden_states,
and dflash). Gemma4 now implements SupportsEagle3 via EagleModelMixin,
so add 'gemma4' to the list to unblock the validation.
The check uses substring matching against model_type ('gemma4_text'
contains 'gemma4'), which is consistent with how other model types
are matched (e.g. 'qwen' matches 'qwen3_moe', 'qwen2', etc.).
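The substring matching described above is simple enough to sketch. This is an illustrative stand-in, not the validator's real code; the function name and allowlist contents here are assumptions.

```python
# Minimal sketch of a substring-based model-type allowlist check, as
# described above. ALLOWED and the function name are illustrative.
ALLOWED = ("qwen", "gemma4")


def supports_aux_hidden_states(model_type: str) -> bool:
    # 'gemma4_text' matches because it contains 'gemma4', mirroring how
    # 'qwen' matches 'qwen3_moe', 'qwen2', etc.
    return any(name in model_type for name in ALLOWED)
```

Note the trade-off in this style of matching: it keeps the allowlist short, but any future model_type that happens to contain an allowed substring will also pass.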
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Claude
The resolved architecture for Gemma4 is Gemma4ForConditionalGeneration (the multimodal wrapper in gemma4_mm.py). The supports_eagle3() check in gpu_model_runner runs on this outer wrapper, so adding SupportsEagle3 only to Gemma4ForCausalLM was insufficient.

SupportsEagle3.set_aux_hidden_state_layers delegates via get_language_model() -> Gemma4ForCausalLM.model (Gemma4Model, which mixes in EagleModelMixin), so no extra implementation is needed in the wrapper.

Co-Authored-By: Claude <noreply@anthropic.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
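The delegation chain described above can be shown with stub classes. This is a hedged sketch following the names in the PR description (Gemma4ForConditionalGeneration, get_language_model, set_aux_hidden_state_layers); the class bodies are toys, not the real model code.

```python
# Toy sketch of the delegation path: outer multimodal wrapper ->
# get_language_model() -> causal-LM .model attribute (the mixin holder).
class Gemma4Model:
    """Stand-in for the inner model that mixes in EagleModelMixin."""

    def __init__(self):
        self.aux_hidden_state_layers = ()


class Gemma4ForCausalLM:
    def __init__(self):
        self.model = Gemma4Model()


class Gemma4ForConditionalGeneration:
    """Stand-in for the multimodal wrapper that the runner sees."""

    def __init__(self):
        self.language_model = Gemma4ForCausalLM()

    def get_language_model(self):
        return self.language_model

    def set_aux_hidden_state_layers(self, layers):
        # Delegate down to the inner Gemma4Model, as described above;
        # the wrapper itself needs no extra state.
        self.get_language_model().model.aux_hidden_state_layers = layers


wrapper = Gemma4ForConditionalGeneration()
wrapper.set_aux_hidden_state_layers((2, 20, 38))  # layer indices are illustrative
```

Because the setter only forwards, marking the wrapper with SupportsEagle3 is enough for the supports_eagle3() check while the behavior stays in one place.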
Output:

Prompt: Generate a sentence with hidden states
Prompt token ids: [43204, 496, 13315, 607, 11497, 5022]
Prompt hidden states path: /tmp/tmpgtw8bf1u/0-87d1df37.safetensors
Extracted token ids: tensor([43204, 496, 13315, 607, 11497, 5022])
Extracted hidden states shape: torch.Size([6, 1, 5376])
Extracted hidden states: tensor([[[ 0.0063,  0.0258, -0.0055,  ...,  0.0058,  0.0134, -0.0349]],
        [[-0.0293, -0.0229,  0.0025,  ..., -0.0054,  0.0320,  0.0386]],
        [[-0.0044,  0.0017, -0.0014,  ...,  0.0013, -0.0069, -0.0364]],
        [[ 0.0036,  0.0081, -0.0023,  ...,  0.0014,  0.0033,  0.0101]],
        [[-0.0024,  0.0109,  0.0077,  ...,  0.0026,  0.0049,  0.0079]],
        [[ 0.0055,  0.0055, -0.0080,  ...,  0.0040,  0.0160, -0.0110]]], dtype=torch.bfloat16)

Prompt: Write a python function
Prompt token ids: [6974, 496, 23181, 1292]
Prompt hidden states path: /tmp/tmpgtw8bf1u/1-a76e7e39.safetensors
Extracted token ids: tensor([ 6974,   496, 23181,  1292])
Extracted hidden states shape: torch.Size([4, 1, 5376])
Extracted hidden states: tensor([[[-0.0028,  0.0074,  0.0009,  ..., -0.0049, -0.0020, -0.0183]],
        [[-0.0269, -0.0239,  0.0028,  ..., -0.0037,  0.0291,  0.0369]],
        [[-0.0063,  0.0049, -0.0017,  ..., -0.0007,  0.0024,  0.0199]],
        [[-0.0072,  0.0078,  0.0041,  ...,  0.0021,  0.0100,  0.0160]]], dtype=torch.bfloat16)