nnx.cached_partial breaks nnx.Module.sow -> nnx.pop pattern #5130

Description

@am001122

The nnx.cached_partial docs explicitly call out the ability to use sow and pop:

Temporary mutations are allowed (e.g. the use of Module.sow) as long as they are cleaned up before the function returns (e.g. via nnx.pop).

However, the following minimal example (with some added debug prints) fails:

import jax.numpy as jnp
from flax import nnx


class Mod(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 4, rngs=rngs)

    def __call__(self, x):
        y = self.linear(x)
        self.sow(nnx.Intermediate, "my_summary", y.mean())
        return y * 2


@nnx.jit
def train_step(model: Mod, x):
    before_call = nnx.graphdef(model)
    print(f"{before_call==nnx.graphdef(model)=}")
    out = model(x)
    after_call = nnx.graphdef(model)
    intermediates = nnx.pop(model, nnx.Intermediate)
    after_pop = nnx.graphdef(model)
    print(
        f"{before_call==after_call=}, {before_call==after_pop=}, {after_call==after_pop=}"
    )
    print(
        f"{before_call.nodes==after_pop.nodes=}, {before_call.num_leaves==after_pop.num_leaves=}"
    )
    print(f"{before_call.attributes==after_pop.attributes=}, differing elements:")
    print(
        "\n".join(
            f"{e1} != {e3}"
            for e1, e3 in zip(before_call.attributes, after_pop.attributes)
            if e1 != e3
        )
    )
    return out, intermediates


x = jnp.ones((2, 4))
model = Mod(nnx.Rngs(42))
train_step_fn = nnx.cached_partial(train_step, model)
out, inter = train_step_fn(x)

produces

before_call==nnx.graphdef(model)=True
before_call==after_call=False, before_call==after_pop=False, after_call==after_pop=False
before_call.nodes==after_pop.nodes=True, before_call.num_leaves==after_pop.num_leaves=True
before_call.attributes==after_pop.attributes=False, differing elements:
('_pytree__nodes', Static(value={'_pytree__state': True, 'linear': True, '_pytree__nodes': False})) != ('_pytree__nodes', Static(value={'_pytree__state': True, 'linear': True, '_pytree__nodes': False, 'my_summary': True}))
...
File ".../lib/python3.13/site-packages/flax/nnx/graph.py", line 1912, in unflatten
    raise ValueError(
    ...<2 lines>...
    )
ValueError: The graph structure of a node added to cached_partial was mutated inside the transformation, this is not allowed.
...
Full traceback

Traceback (most recent call last):
    out, inter = train_step_fn(x)
                 ~~~~~~~~~~~~~^^^
  File ".../lib/python3.13/site-packages/flax/nnx/graph.py", line 1654, in cache_args_wrapper
    return f(*cached_args, *args, **kwargs)
  File ".../lib/python3.13/site-packages/flax/nnx/transforms/compilation.py", line 477, in __call__
    out = self._get_non_pure_out(pure_args_out, pure_kwargs_out, pure_out)
  File ".../lib/python3.13/site-packages/flax/nnx/transforms/compilation.py", line 462, in _get_non_pure_out
    _args_out, _kwargs_out, out = extract.from_tree(
                                  ~~~~~~~~~~~~~~~~~^
      (pure_args_out, pure_kwargs_out, pure_out),
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<2 lines>...
      ctxtag=self,
      ^^^^^^^^^^^^
    )
    ^
  File ".../lib/python3.13/site-packages/flax/nnx/extract.py", line 301, in from_tree
    return jax.tree.map(maybe_split, tree, is_leaf=is_leaf)
           ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.13/site-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.13/site-packages/jax/_src/tree_util.py", line 361, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.13/site-packages/jax/_src/tree_util.py", line 361, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ~^^^^^
  File ".../lib/python3.13/site-packages/flax/nnx/extract.py", line 298, in maybe_split
    return merge_fn(merge_ctx, (), prefix, x)
  File ".../lib/python3.13/site-packages/flax/nnx/transforms/compilation.py", line 110, in _jit_merge_fn
    return ctx.unflatten(leaf.graphdef, *leaf.states)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.13/site-packages/flax/nnx/graph.py", line 1912, in unflatten
    raise ValueError(
    ...<2 lines>...
    )
ValueError: The graph structure of a node added to cached_partial was mutated inside the transformation, this is not allowed.
Node: Mod( # Param: 20 (80 B)
  linear=Linear( # Param: 20 (80 B)
    bias=Param( # 4 (16 B)
      value=Array(shape=(4,), dtype=dtype('float32'))
    ),
    dot_general=<function dot_general at 0x7072b761de40>,
    dtype=None,
    in_features=4,
    kernel=Param( # 16 (64 B)
      value=Array(shape=(4, 4), dtype=dtype('float32'))
    ),
    out_features=4,
    param_dtype=float32,
    precision=None,
    preferred_element_type=None,
    promote_dtype=<function promote_dtype at 0x7072b4392160>,
    use_bias=True
  )
)
Ouput graphdef: GraphDef(nodes=[NodeDef(
  type='Mod',
  index=0,
  outer_index=0,
  num_attributes=3,
  metadata=Mod
), NodeDef(
  type='GenericPytree',
  index=None,
  outer_index=None,
  num_attributes=0,
  metadata=({}, PyTreeDef(CustomNode(PytreeState[(False, False)], [])))
), NodeDef(
  type='Linear',
  index=1,
  outer_index=1,
  num_attributes=13,
  metadata=Linear
), NodeDef(
  type='GenericPytree',
  index=None,
  outer_index=None,
  num_attributes=0,
  metadata=({}, PyTreeDef(CustomNode(PytreeState[(False, False)], [])))
), VariableDef(
  type='Param',
  index=2,
  outer_index=2,
  metadata=PrettyMapping({
    'is_hijax': False,
    'has_ref': False,
    'is_mutable': True,
    'eager_sharding': True
  })
), NodeDef(
  type='NoneType',
  index=None,
  outer_index=None,
  num_attributes=0,
  metadata=None
), VariableDef(
  type='Param',
  index=3,
  outer_index=3,
  metadata=PrettyMapping({
    'is_hijax': False,
    'has_ref': False,
    'is_mutable': True,
    'eager_sharding': True
  })
), NodeDef(
  type='NoneType',
  index=None,
  outer_index=None,
  num_attributes=0,
  metadata=None
), NodeDef(
  type='NoneType',
  index=None,
  outer_index=None,
  num_attributes=0,
  metadata=None
)], attributes=[('_pytree__nodes', Static(value={'_pytree__state': True, 'linear': True, '_pytree__nodes': False, 'my_summary': True})), ('_pytree__state', NodeAttr()), ('linear', NodeAttr()), ('_pytree__nodes', Static(value={'_pytree__state': True, 'kernel': True, 'bias': True, 'in_features': False, 'out_features': False, 'use_bias': False, 'dtype': False, 'param_dtype': False, 'precision': False, 'dot_general': False, 'promote_dtype': False, 'preferred_element_type': False, '_pytree__nodes': False})), ('_pytree__state', NodeAttr()), ('bias', NodeAttr()), ('dot_general', Static(value=<function dot_general at 0x7072b761de40>)), ('dtype', NodeAttr()), ('in_features', Static(value=4)), ('kernel', NodeAttr()), ('out_features', Static(value=4)), ('param_dtype', Static(value=<class 'jax.numpy.float32'>)), ('precision', NodeAttr()), ('preferred_element_type', NodeAttr()), ('promote_dtype', Static(value=<function promote_dtype at 0x7072b4392160>)), ('use_bias', Static(value=True))], num_leaves=2)
Expected graphdef: GraphDef(nodes=[NodeDef(
  type='Mod',
  index=0,
  outer_index=0,
  num_attributes=3,
  metadata=Mod
), NodeDef(
  type='GenericPytree',
  index=None,
  outer_index=None,
  num_attributes=0,
  metadata=({}, PyTreeDef(CustomNode(PytreeState[(False, False)], [])))
), NodeDef(
  type='Linear',
  index=1,
  outer_index=1,
  num_attributes=13,
  metadata=Linear
), NodeDef(
  type='GenericPytree',
  index=None,
  outer_index=None,
  num_attributes=0,
  metadata=({}, PyTreeDef(CustomNode(PytreeState[(False, False)], [])))
), VariableDef(
  type='Param',
  index=2,
  outer_index=2,
  metadata=PrettyMapping({
    'is_hijax': False,
    'has_ref': False,
    'is_mutable': True,
    'eager_sharding': True
  })
), NodeDef(
  type='NoneType',
  index=None,
  outer_index=None,
  num_attributes=0,
  metadata=None
), VariableDef(
  type='Param',
  index=3,
  outer_index=3,
  metadata=PrettyMapping({
    'is_hijax': False,
    'has_ref': False,
    'is_mutable': True,
    'eager_sharding': True
  })
), NodeDef(
  type='NoneType',
  index=None,
  outer_index=None,
  num_attributes=0,
  metadata=None
), NodeDef(
  type='NoneType',
  index=None,
  outer_index=None,
  num_attributes=0,
  metadata=None
)], attributes=[('_pytree__nodes', Static(value={'_pytree__state': True, 'linear': True, '_pytree__nodes': False})), ('_pytree__state', NodeAttr()), ('linear', NodeAttr()), ('_pytree__nodes', Static(value={'_pytree__state': True, 'kernel': True, 'bias': True, 'in_features': False, 'out_features': False, 'use_bias': False, 'dtype': False, 'param_dtype': False, 'precision': False, 'dot_general': False, 'promote_dtype': False, 'preferred_element_type': False, '_pytree__nodes': False})), ('_pytree__state', NodeAttr()), ('bias', NodeAttr()), ('dot_general', Static(value=<function dot_general at 0x7072b761de40>)), ('dtype', NodeAttr()), ('in_features', Static(value=4)), ('kernel', NodeAttr()), ('out_features', Static(value=4)), ('param_dtype', Static(value=<class 'jax.numpy.float32'>)), ('precision', NodeAttr()), ('preferred_element_type', NodeAttr()), ('promote_dtype', Static(value=<function promote_dtype at 0x7072b4392160>)), ('use_bias', Static(value=True))], num_leaves=2)

This appears to be a bug: some digging suggests there is no mirror image of the _setattr logic here:

vars(self)['_pytree__nodes'] = self._pytree__nodes.update({name: data})

that would remove an attribute's entry from _pytree__nodes when that attribute is removed from the module (as nnx.pop does here).
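
To illustrate, here is a rough sketch of the kind of cleanup I would expect on the deletion path. This is purely hypothetical, not actual flax code; the function name is made up and it treats _pytree__nodes as a plain mapping, which may not match the real internal type:

# Hypothetical sketch of the missing counterpart to the quoted _setattr logic.
# The name and the plain-dict handling are assumptions, not flax internals.
def _cleanup_pytree_nodes_entry(obj, name: str) -> None:
    nodes = dict(vars(obj).get('_pytree__nodes', {}))  # copy the current mapping
    nodes.pop(name, None)                               # drop the removed attribute's entry
    vars(obj)['_pytree__nodes'] = nodes                 # write back the cleaned-up mapping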

I think this is a regression: I can reproduce the bug on flax 0.12.1 but not on the 0.10.7 that comes pre-installed in Colab:

| Version | Environment           | Status |
|---------|-----------------------|--------|
| 0.10.7  | Colab (pre-installed) | OK     |
| 0.11.2  | Colab                 | fails  |
| 0.12.0  | Colab                 | fails  |
| 0.12.1  | Colab                 | fails  |
| 0.12.1  | Local                 | fails  |
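
For what it's worth, the mismatch should be reproducible without jit or cached_partial at all, which might make bisecting the regression easier. The following is an untested sketch (continuing from the repro script above, so Mod and the imports are assumed to be in scope), based on the attributes diff shown earlier:

# Untested sketch: compare graphdefs before sow and after sow + pop,
# outside of any transform. Reuses the Mod class from the repro script.
model2 = Mod(nnx.Rngs(0))
gd_before = nnx.graphdef(model2)

model2(jnp.ones((2, 4)))            # __call__ sows the 'my_summary' Intermediate
nnx.pop(model2, nnx.Intermediate)   # pop removes the Intermediate again

gd_after = nnx.graphdef(model2)
# If the deletion path cleaned up _pytree__nodes, these should be equal again;
# given the diff above I would expect this to print False on 0.12.1.
print(gd_before == gd_after)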
