Description
Summary
When using `output_hidden_states=True` with the Gemma3 model to access intermediate hidden states via `.sow()`, the final logits computation is still materialized even when the returned logits are ignored. This prevents efficient "headless" inference where only embeddings are needed (or, in my case, top-k logit extraction without full logit computation).
Context
The Gemma3 Transformer.__call__ method supports outputting hidden states via .sow():
(https://github.com/google/tunix/blob/main/tunix/models/gemma3/model.py#L966)
```python
if output_hidden_states:
  self.sow(nnx.Intermediate, 'all_hidden_states', x)

logits = self.embedder.decode(x)

return logits, new_cache
```
Expected Behavior
When wrapping the model to only extract the sowed all_hidden_states and ignoring the returned logits, JAX's Dead Code Elimination (DCE) should recognize that:
- The logits are never used
- The `embedder.decode(x)` matmul can be eliminated
- Only the hidden states need to be computed
This would save significant memory and compute, since the embedding→vocab matmul (e.g., [B, L, hidden] @ [hidden, 262144]) is one of the largest operations.
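For comparison, in a pure JAX function an unused head matmul really is removed once the function is compiled. A minimal sketch with illustrative shapes and names (not the Gemma3 model):

```python
import jax
import jax.numpy as jnp

W1 = jnp.ones((10, 20))
W2 = jnp.ones((20, 1000))  # stand-in for the large embedding->vocab head

def forward(x):
    hidden = jax.nn.relu(x @ W1)
    logits = hidden @ W2
    return logits, hidden

def hidden_only(x):
    # Drop the logits on the floor; only the hidden activations escape.
    return forward(x)[1]

x = jnp.ones((1, 10))
hlo = jax.jit(hidden_only).lower(x).compile().as_text()
# The [1, 1000] logits shape should be absent if XLA eliminated the dead dot:
print('f32[1,1000]' in hlo)
```

Here the dead 1000-wide matmul does not reach the optimized HLO, which is the behavior one would hope for in the sowed-hidden-states case as well.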
Actual Behavior
The logits matmul is still present in the jaxpr traced from the jitted function, even when the logits are completely ignored; DCE does not eliminate the unused computation.
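One caveat when reading the jaxpr evidence: `jax.make_jaxpr` records every primitive that was traced, before optimization passes run, so dead equations can appear in its output even in plain JAX where compilation does eliminate them. A small sketch of that distinction (illustrative shapes, not the repro below):

```python
import jax
import jax.numpy as jnp

W = jnp.ones((20, 1000))

def f(h):
    dead = h @ W  # result is never used
    return h.sum()

h = jnp.ones((1, 20))
# The unoptimized jaxpr still lists the dead dot_general:
print(jax.make_jaxpr(f)(h))
# The optimized HLO that actually executes may no longer contain it:
hlo = jax.jit(f).lower(h).compile().as_text()
print('f32[1,1000]' in hlo)
```

So the jaxpr alone may not distinguish "DCE missed this" from "DCE happens later"; comparing against the lowered-and-compiled HLO is a complementary check.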
Steps to Reproduce the Problem
I created a simplified 2-layer NNX model to isolate the issue:
```python
import jax
import jax.numpy as jnp
from flax import nnx

DIM_INPUT = 10
DIM_HIDDEN = 20
DIM_OUTPUT = 1000  # Large - should be DCE'd if unused


class SimpleTwoLayerModel(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.layer1 = nnx.Linear(DIM_INPUT, DIM_HIDDEN, rngs=rngs)
        self.layer2 = nnx.Linear(DIM_HIDDEN, DIM_OUTPUT, rngs=rngs)

    def __call__(self, x, output_hidden=False):
        hidden = nnx.relu(self.layer1(x))
        if output_hidden:
            self.sow(nnx.Intermediate, 'hidden_state', hidden)
        output = self.layer2(hidden)
        return output


class HeadlessWrapper(nnx.Module):
    def __init__(self, model):
        self.model = model

    def __call__(self, x):
        _ = self.model(x, output_hidden=True)  # Ignore output!
        sow_state = nnx.pop(self.model, nnx.Intermediate)
        return sow_state['hidden_state'].value[0]


# Test
rngs = nnx.Rngs(0)
base_model = SimpleTwoLayerModel(rngs)
headless = HeadlessWrapper(base_model)

@nnx.jit
def run_headless(model, x):
    return model(x)

x = jnp.ones((1, DIM_INPUT))
jaxpr = jax.make_jaxpr(run_headless)(headless, x)
print(jaxpr)  # Shows matmul with DIM_OUTPUT still present
```
Colab notebook with full test:
https://colab.research.google.com/drive/1-gAlL5fHwZvaEVskHtiCfAz0xsLiElKu?usp=sharing
Environment
Per the provided Colab:
- Python 3.12.12
- flax==0.12.3
- jax==0.9.0

Checklist
- [x] I have searched the existing issues for a similar bug report.
- [x] I have provided all the required information in the "Environment" section.
- [x] I have provided a minimal, reproducible example.
Would you like to help us fix it?
Unless there is something obvious that I'm missing, this looks like an upstream Flax NNX (or JAX) issue whose resolution is largely in Google's hands - but any feedback here would be welcome.