
.sow() intermediate values don't enable Dead Code Elimination of unused outputs #1049

@mdda

Description


Summary

When using output_hidden_states=True with the Gemma3 model to access intermediate hidden states via .sow(), the final logits computation is still materialized even when the returned logits are ignored. This prevents efficient "headless" inference where only embeddings are needed (or, in my case, top-k logit extraction without computing the full logits).

Context

The Gemma3 Transformer.__call__ method supports outputting hidden states via .sow():

(https://github.com/google/tunix/blob/main/tunix/models/gemma3/model.py#L966)

if output_hidden_states:
    self.sow(nnx.Intermediate, 'all_hidden_states', x)
logits = self.embedder.decode(x)
return logits, new_cache

Expected Behavior

When the model is wrapped so that only the sowed all_hidden_states are extracted and the returned logits are discarded, JAX's Dead Code Elimination (DCE) should recognize that:

  1. The logits are never used
  2. The embedder.decode(x) matmul can be eliminated
  3. Only the hidden states need to be computed

This would save significant memory and compute, since the embedding→vocab matmul (e.g., [B, L, hidden] @ [hidden, 262144]) is one of the largest operations.
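For comparison, plain JAX (no nnx, no .sow()) does let the unused matmul disappear by compile time. A minimal sketch with made-up small shapes, checking the optimized HLO for the logits matmul's result shape:

```python
import jax
import jax.numpy as jnp

def two_layer(x, w1, w2):
    hidden = jnp.tanh(x @ w1)
    logits = hidden @ w2  # large matmul, analogous to embedder.decode
    return hidden, logits

x = jnp.ones((1, 8))
w1 = jnp.ones((8, 16))
w2 = jnp.ones((16, 1000))

# Keep only the hidden output; XLA is then free to eliminate the
# logits matmul during compilation.
hidden_only = jax.jit(lambda x, w1, w2: two_layer(x, w1, w2)[0])
hlo = hidden_only.lower(x, w1, w2).compile().as_text()
print("logits matmul survived:", "f32[1,1000]" in hlo)
```

Here the dead dot producing an f32[1,1000] result is removed by XLA's optimization passes, which is the behavior I would hope for from the .sow()-based path as well.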

Actual Behavior

The logits matmul is still present in the traced jaxpr, even when the logits are completely ignored: DCE does not eliminate the unused computation.

Steps to Reproduce the Problem

I created a simplified 2-layer NNX model to isolate the issue:

import jax
import jax.numpy as jnp
from flax import nnx

DIM_INPUT = 10
DIM_HIDDEN = 20
DIM_OUTPUT = 1000  # Large - should be DCE'd if unused

class SimpleTwoLayerModel(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.layer1 = nnx.Linear(DIM_INPUT, DIM_HIDDEN, rngs=rngs)
        self.layer2 = nnx.Linear(DIM_HIDDEN, DIM_OUTPUT, rngs=rngs)

    def __call__(self, x, output_hidden=False):
        hidden = nnx.relu(self.layer1(x))
        if output_hidden:
            self.sow(nnx.Intermediate, 'hidden_state', hidden)
        output = self.layer2(hidden)
        return output

class HeadlessWrapper(nnx.Module):
    def __init__(self, model):
        self.model = model

    def __call__(self, x):
        _ = self.model(x, output_hidden=True)  # Ignore output!
        sow_state = nnx.pop(self.model, nnx.Intermediate)
        return sow_state['hidden_state'].value[0]

# Test
rngs = nnx.Rngs(0)
base_model = SimpleTwoLayerModel(rngs)
headless = HeadlessWrapper(base_model)

@nnx.jit
def run_headless(model, x):
    return model(x)

x = jnp.ones((1, DIM_INPUT))
jaxpr = jax.make_jaxpr(run_headless)(headless, x)
print(jaxpr)  # Shows matmul with DIM_OUTPUT still present

Colab notebook with full test:

https://colab.research.google.com/drive/1-gAlL5fHwZvaEVskHtiCfAz0xsLiElKu?usp=sharing

Environment

As reported in the provided Colab:

Python 3.12.12
flax==0.12.3
jax==0.9.0

Checklist

  • [x] I have searched the existing issues for a similar bug report.
  • [x] I have provided all the required information in the "Environment" section.
  • [x] I have provided a minimal, reproducible example.

Would you like to help us fix it?

Unless there is something obvious that I'm missing, this looks like an upstream nnx (or jax) issue, with the resolution largely in Google's hands - but any feedback here would be welcome.

Labels

type:bug (Something isn't working), type:performance (Make things lean and fast)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions