Skip to content

Commit 124e33f

Browse files
Cristian GarciaFlax Authors
authored andcommitted
maintain data/static definition in split / state
`nnx.graph` APIs now follow the data / static definitions for `nnx.Pytree` instances. Previously `nnx.graph.flatten` match the default definitions for Pytree most of the time but deviated on edcases, now they always match. Example ```python class Foo(nnx.Pytree): def __init__(self, data, static): self.data = nnx.data(data) self.static = nnx.static(static) tree = Foo(1, 2) state = nnx.state(tree) # previously this was false assert 'data' in state ``` PiperOrigin-RevId: 866680531
1 parent d87f7da commit 124e33f

File tree

3 files changed

+85
-57
lines changed

3 files changed

+85
-57
lines changed

flax/nnx/graph.py

Lines changed: 58 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,13 @@ class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
179179
type: type[Node]
180180
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]]
181181

182-
def node_dict(self, node: Node) -> dict[Key, Leaf]:
182+
def node_dict(self, node: Node) -> dict[Key, tp.Any]:
183183
nodes, _ = self.flatten(node)
184-
return dict(nodes)
184+
nodes = {
185+
key: node.value if isinstance(node, DataElem | StaticElem) else node
186+
for key, node in nodes
187+
}
188+
return nodes
185189

186190

187191
@dataclasses.dataclass(frozen=True, slots=True)
@@ -524,32 +528,21 @@ def __treescope_repr__(self, path, subtree_renderer):
524528

525529

526530
@dataclasses.dataclass(frozen=True, slots=True)
527-
class ArrayAttr:
528-
pass
529-
530-
531-
ARRAY_ATTR = ArrayAttr()
532-
533-
534-
@dataclasses.dataclass(frozen=True, slots=True)
535-
class MutableArrayAttr:
531+
class NodeAttr:
536532
pass
537533

538534

539-
MUTABLE_ARRAY_ATTR = MutableArrayAttr()
540-
535+
NODE_ATTR = NodeAttr()
541536

542537
@dataclasses.dataclass(frozen=True, slots=True)
543-
class NodeAttr:
538+
class LeafAttr:
544539
pass
545540

546-
547-
NODE_ATTR = NodeAttr()
541+
LEAF_ATTR = LeafAttr()
548542

549543
AttrType = tp.Union[
550544
NodeAttr,
551-
ArrayAttr,
552-
MutableArrayAttr,
545+
LeafAttr,
553546
'Static[tp.Any]',
554547
]
555548

@@ -711,6 +704,14 @@ def flatten( # type: ignore[invalid-annotation]
711704
else:
712705
return graphdef, leaves
713706

707+
@dataclasses.dataclass(frozen=True, slots=True)
708+
class DataElem:
709+
value: tp.Any
710+
711+
712+
@dataclasses.dataclass(frozen=True, slots=True)
713+
class StaticElem:
714+
value: tp.Any
714715

715716
def _graph_flatten(
716717
node: Node,
@@ -828,6 +829,18 @@ def make_mutable_arraydef(value: variablelib.Ref):
828829
nodes.append(nodedef)
829830

830831
for key, value in values:
832+
is_data = None
833+
if isinstance(value, DataElem):
834+
value = value.value
835+
is_data = True
836+
elif isinstance(value, StaticElem):
837+
value = value.value
838+
is_data = False
839+
840+
if is_data is False:
841+
attributes.append((key, Static(value)))
842+
continue
843+
831844
value_node_impl = get_node_impl(value)
832845
if path is not None:
833846
path.append(key)
@@ -845,15 +858,15 @@ def make_mutable_arraydef(value: variablelib.Ref):
845858
paths,
846859
)
847860
elif variablelib.is_array_ref(value):
848-
attributes.append((key, MUTABLE_ARRAY_ATTR))
861+
attributes.append((key, NODE_ATTR))
849862
array_refdef, leaf = make_mutable_arraydef(value)
850863
if not isinstance(leaf, Repeated):
851864
leaves.append(leaf)
852865
if paths is not None:
853866
paths.append(tuple(path)) # type: ignore
854867
nodes.append(array_refdef)
855-
elif isinstance(value, (jax.Array, np.ndarray)):
856-
attributes.append((key, ARRAY_ATTR))
868+
elif isinstance(value, (jax.Array, np.ndarray)) or is_data:
869+
attributes.append((key, LEAF_ATTR))
857870
if paths is not None:
858871
paths.append(tuple(path)) # type: ignore
859872
leaves.append(value)
@@ -1093,41 +1106,33 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
10931106
key, value = next(attribute_iter)
10941107
if type(value) is Static:
10951108
children.append((key, value.value)) # type: ignore[attribute-error]
1096-
elif type(value) is MutableArrayAttr:
1097-
array_refdef = next(node_iter)
1098-
assert (
1099-
type(array_refdef) is ArrayRefDef or type(array_refdef) is NodeRef
1100-
)
1101-
if type(array_refdef) is NodeRef:
1102-
array_ref = index_ref[array_refdef.index]
1103-
else:
1104-
assert type(array_refdef) is ArrayRefDef
1109+
elif type(value) is LeafAttr:
1110+
leaf = next(leaves_iter)
1111+
children.append((key, leaf))
1112+
elif type(value) is NodeAttr:
1113+
node_def = next(node_iter)
1114+
if isinstance(node_def, NodeRef):
1115+
node = index_ref[node_def.index]
1116+
elif isinstance(node_def, ArrayRefDef):
11051117
leaf = next(leaves_iter)
1106-
array_ref = get_mutable_array(array_refdef, leaf)
1107-
children.append((key, array_ref))
1108-
elif type(value) is ArrayAttr:
1109-
array = next(leaves_iter)
1110-
children.append((key, array))
1118+
node = get_mutable_array(node_def, leaf)
1119+
elif isinstance(node_def, NodeDef | VariableDef):
1120+
value_node_impl = get_node_impl_for_type(node_def.type)
1121+
node = _graph_unflatten(
1122+
node_def,
1123+
value_node_impl,
1124+
node_iter,
1125+
attribute_iter,
1126+
leaves_iter,
1127+
index_ref,
1128+
outer_index_outer_ref,
1129+
copy_variables,
1130+
)
1131+
else:
1132+
raise RuntimeError(f'Unknown node definition: {node_def!r}')
1133+
children.append((key, node))
11111134
elif type(value) is NodeRef:
11121135
children.append((key, index_ref[value.index])) # type: ignore[attribute-error]
1113-
elif type(value) is NodeAttr:
1114-
# if the key is a subgraph we create an empty node
1115-
subgraphdef = next(node_iter)
1116-
if type(subgraphdef) is NodeDef:
1117-
value_node_impl = get_node_impl_for_type(subgraphdef.type) # type: ignore[attribute-error]
1118-
else:
1119-
value_node_impl = None
1120-
subnode = _graph_unflatten(
1121-
subgraphdef,
1122-
value_node_impl,
1123-
node_iter,
1124-
attribute_iter,
1125-
leaves_iter,
1126-
index_ref,
1127-
outer_index_outer_ref,
1128-
copy_variables,
1129-
)
1130-
children.append((key, subnode))
11311136
else:
11321137
raise RuntimeError(f'Unknown static field: {key!r}')
11331138

flax/nnx/pytreelib.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,16 @@ def _pytree__unflatten(
949949
# Graph Definition
950950
# -------------------------
951951
def _graph_node_flatten(self):
952-
obj_items = vars(self).items()
952+
pytree_nodes = self._pytree__nodes
953+
obj_items = (
954+
(
955+
name,
956+
nnx.graph.DataElem(value)
957+
if name in pytree_nodes and pytree_nodes[name]
958+
else nnx.graph.StaticElem(value),
959+
)
960+
for name, value in vars(self).items()
961+
)
953962
if self._pytree__has_int_keys:
954963
obj_items = ((_maybe_int(name), value) for name, value in obj_items)
955964
key_fn = graph._type_aware_sort

tests/nnx/graph_utils_test.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ def __init__(self):
629629
self.assertFalse(hasattr(ctx, 'ctxtag'))
630630
self.assertIsInstance(graphdef1.nodes[0], nnx.graph.NodeDef)
631631
self.assertIsInstance(graphdef2.nodes[0], nnx.graph.NodeRef)
632-
self.assertLen(nnx.to_flat_state(state1), 1)
632+
self.assertLen(nnx.to_flat_state(state1), 2)
633633
self.assertLen(nnx.to_flat_state(state2), 0)
634634

635635
@jax.jit
@@ -717,7 +717,7 @@ def __init__(self):
717717
assert isinstance(t2, nnx.NodeStates)
718718
self.assertIsInstance(t1.graphdef.nodes[0], nnx.graph.NodeDef)
719719
self.assertIsInstance(t2.graphdef.nodes[0], nnx.graph.NodeRef)
720-
self.assertLen(nnx.to_flat_state(t1.states[0]), 1)
720+
self.assertLen(nnx.to_flat_state(t1.states[0]), 2)
721721
self.assertLen(nnx.to_flat_state(t2.states[0]), 0)
722722

723723
@jax.jit
@@ -744,7 +744,7 @@ def f(pure_tree):
744744
assert isinstance(t2, nnx.NodeStates)
745745
self.assertIsInstance(t1.graphdef.nodes[0], nnx.graph.NodeDef)
746746
self.assertIsInstance(t2.graphdef.nodes[0], nnx.graph.NodeRef)
747-
self.assertLen(nnx.to_flat_state(t1.states[0]), 1)
747+
self.assertLen(nnx.to_flat_state(t1.states[0]), 2)
748748
self.assertLen(nnx.to_flat_state(t2.states[0]), 0)
749749

750750
return pure_tree2
@@ -762,6 +762,20 @@ def f(pure_tree):
762762
self.assertEqual(m.b[...], 1) # type: ignore
763763
self.assertEqual(impure_tree2[1], 1)
764764

765+
def test_graph_flatten_with_data_wrapper(self):
766+
class Foo(nnx.Pytree):
767+
def __init__(self, data, static):
768+
self.data = nnx.data(data)
769+
self.static = nnx.static(static)
770+
771+
tree = Foo(1, 2)
772+
state = nnx.state(tree)
773+
774+
self.assertIn('data', state)
775+
self.assertIsInstance(state['data'], int)
776+
self.assertEqual(state['data'], 1)
777+
self.assertNotIn('static', state)
778+
765779
def test_to_tree_consistent_prefix(self):
766780
m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
767781
impure_tree = (m, 1, {'b': m})

0 commit comments

Comments
 (0)