Skip to content

Commit f9a6055

Browse files
author
Flax Authors
committed
Merge pull request #4869 from google:update-for-dict
PiperOrigin-RevId: 792381132
2 parents 220910f + b5f6768 commit f9a6055

File tree

2 files changed

+99
-43
lines changed

2 files changed

+99
-43
lines changed

flax/nnx/graph.py

Lines changed: 60 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
196196
@dataclasses.dataclass(frozen=True, slots=True)
197197
class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
198198
unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node]
199+
set_key: tp.Callable[[Node, Key, Leaf], None] | None
200+
pop_key: tp.Callable[[Node, Key], Leaf] | None
199201

200202

201203
NodeImpl = tp.Union[
@@ -234,12 +236,19 @@ def register_pytree_node_type(
234236
type: type,
235237
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]],
236238
unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node],
239+
*,
240+
set_key: tp.Callable[[Node, Key, Leaf], None] | None = None,
241+
pop_key: tp.Callable[[Node, Key], Leaf] | None = None,
237242
):
238243
if type in PYTREE_REGISTRY:
239244
raise ValueError(f'Node type {type} is already registered.')
240245

241246
PYTREE_REGISTRY[type] = PytreeNodeImpl(
242-
type=type, flatten=flatten, unflatten=unflatten
247+
type=type,
248+
flatten=flatten,
249+
unflatten=unflatten,
250+
set_key=set_key,
251+
pop_key=pop_key,
243252
)
244253

245254

@@ -1146,16 +1155,16 @@ def _graph_unflatten(
11461155
) -> Node:
11471156
"""Recursive helper for graph_unflatten.
11481157
1149-
Args:
1150-
nodedef: A GraphDef instance or an index to a node in the cache.
1151-
state: A mapping from attribute names to variables or subgraphs.
1152-
index_to_ref: A mapping from indexes to nodes that have been traversed.
1153-
If a node is already in the cache, it won't be traversed again.
1154-
f0f6619b-dde6-4466-b699-61c47f268d6b index_ref_cache: A mapping from indexes to existing nodes that can be reused.
1155-
When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the
1156-
object in an empty state and then filled by the unflatten process, as a result
1157-
existing graph nodes are mutated to have the new content/topology
1158-
specified by the nodedef.
1158+
Args:
1159+
nodedef: A GraphDef instance or an index to a node in the cache.
1160+
state: A mapping from attribute names to variables or subgraphs.
1161+
index_ref: A mapping from indexes to nodes that have been traversed.
1162+
If a node is already in the cache, it won't be traversed again.
1163+
outer_index_outer_ref: A mapping from indexes to existing nodes that can be reused.
1164+
When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the
1165+
object in an empty state and then filled by the unflatten process, as a result
1166+
existing graph nodes are mutated to have the new content/topology
1167+
specified by the nodedef.
11591168
"""
11601169

11611170
def get_mutable_array(array_refdef: ArrayRefDef, leaf):
@@ -1168,13 +1177,9 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf):
11681177
# if array ref exists, update it
11691178
array_ref = outer_index_outer_ref[array_refdef.outer_index]
11701179
if not variablelib.is_array_ref(array_ref):
1171-
raise RuntimeError(
1172-
f'Expected a ArrayRef type but got {array_ref}.'
1173-
)
1180+
raise RuntimeError(f'Expected a ArrayRef type but got {array_ref}.')
11741181
if type(leaf) is not NoUpdate:
1175-
raise RuntimeError(
1176-
f'Expected a no update for ArrayRef but got {leaf}.'
1177-
)
1182+
raise RuntimeError(f'Expected a no update for ArrayRef but got {leaf}.')
11781183
elif type(leaf) in (NoUpdate, Repeated):
11791184
raise ValueError(
11801185
f"Expected a ArrayRefOutput type but got '{leaf.value}.'"
@@ -1206,9 +1211,7 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf):
12061211
if isinstance(value, Variable):
12071212
value = value.copy() if copy_variables else value
12081213
inner_value = value.raw_value
1209-
array_ref = get_mutable_array(
1210-
variabledef.array_refdef, inner_value
1211-
)
1214+
array_ref = get_mutable_array(variabledef.array_refdef, inner_value)
12121215
if array_ref is not inner_value:
12131216
value.raw_value = array_ref
12141217
else:
@@ -1268,8 +1271,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
12681271
elif type(value) is MutableArrayAttr:
12691272
array_refdef = next(node_iter)
12701273
assert (
1271-
type(array_refdef) is ArrayRefDef
1272-
or type(array_refdef) is NodeRef
1274+
type(array_refdef) is ArrayRefDef or type(array_refdef) is NodeRef
12731275
)
12741276
if type(array_refdef) is NodeRef:
12751277
array_ref = index_ref[array_refdef.index]
@@ -1392,7 +1394,7 @@ def _graph_pop(
13921394

13931395
for state, predicate in zip(flat_states, predicates):
13941396
if predicate(node_path, value):
1395-
if isinstance(node_impl, PytreeNodeImpl):
1397+
if node_impl.pop_key is None:
13961398
raise ValueError(
13971399
f'Cannot pop key {name!r} from node of type {type(node).__name__}'
13981400
)
@@ -1441,7 +1443,7 @@ def _update_variable(node: Variable, value):
14411443
for key, value in state.items():
14421444
# case 1: new state is being added
14431445
if key not in node_dict:
1444-
if isinstance(node_impl, PytreeNodeImpl):
1446+
if node_impl.set_key is None:
14451447
raise ValueError(
14461448
f'Cannot set key {key!r} on immutable node of '
14471449
f'type {type(node).__name__}'
@@ -1460,22 +1462,16 @@ def _update_variable(node: Variable, value):
14601462
if is_node_leaf(value):
14611463
raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}')
14621464
_graph_update_dynamic(current_value, value)
1463-
else:
1464-
if isinstance(current_value, jax.Array | np.ndarray):
1465-
if isinstance(node_impl, PytreeNodeImpl):
1466-
raise ValueError(
1467-
f'Cannot set key {key!r} on immutable node of '
1468-
f'type {type(node).__name__}'
1469-
)
1470-
node_impl.set_key(node, key, value)
1471-
continue
1472-
elif not isinstance(current_value, Variable):
1473-
# case 3: state leaf is being updated
1474-
raise ValueError(
1475-
f'Trying to update a non-Variable attribute {key!r} with a Variable: '
1476-
f'{value!r}'
1477-
)
1465+
elif isinstance(current_value, Variable):
14781466
_update_variable(current_value, value)
1467+
elif node_impl.set_key is not None:
1468+
node_impl.set_key(node, key, value)
1469+
else:
1470+
raise ValueError(
1471+
f'Cannot set key {key!r} on immutable node of '
1472+
f'type {type(node).__name__}'
1473+
)
1474+
14791475

14801476

14811477
# --------------------------------------------------------
@@ -2621,9 +2617,7 @@ def find_duplicates(
26212617

26222618

26232619
def _mutable_like(path, x):
2624-
return (
2625-
isinstance(x, Variable) and x.has_ref
2626-
) or variablelib.is_array_ref(x)
2620+
return (isinstance(x, Variable) and x.has_ref) or variablelib.is_array_ref(x)
26272621

26282622

26292623
def to_arrays(
@@ -3021,6 +3015,8 @@ def _unflatten_pytree(
30213015
type=GenericPytree,
30223016
flatten=_flatten_pytree,
30233017
unflatten=_unflatten_pytree, # type: ignore
3018+
set_key=None,
3019+
pop_key=None,
30243020
)
30253021

30263022
# common pytrees
@@ -3036,15 +3032,36 @@ def _unflatten_pytree(
30363032
flatten=lambda x: (list(enumerate(x)), None),
30373033
unflatten=lambda nodes, _: tuple(value for _, value in nodes), # type: ignore
30383034
)
3035+
3036+
3037+
def _mutable_mapping_set_key(x: tp.MutableMapping[Key, tp.Any], key: Key, value: tp.Any):
3038+
x[key] = value
3039+
3040+
3041+
def _mutable_mapping_pop_key(x: tp.MutableMapping[Key, tp.Any], key: Key):
3042+
x.pop(key)
3043+
3044+
30393045
# dict
30403046
register_pytree_node_type(
30413047
dict,
30423048
flatten=lambda x: (sorted(x.items()), None),
3043-
unflatten=lambda nodes, _: {key: value for key, value in nodes}, # type: ignore
3049+
unflatten=lambda nodes, _: dict(nodes), # type: ignore
3050+
set_key=_mutable_mapping_set_key,
3051+
pop_key=_mutable_mapping_pop_key,
3052+
)
3053+
# State
3054+
register_pytree_node_type(
3055+
State,
3056+
flatten=lambda x: (sorted(x.raw_mapping.items()), None),
3057+
unflatten=lambda nodes, _: State(nodes), # type: ignore
3058+
set_key=_mutable_mapping_set_key,
3059+
pop_key=_mutable_mapping_pop_key,
30443060
)
30453061
# None
30463062
register_pytree_node_type(
30473063
type(None),
30483064
flatten=lambda x: ([], None),
30493065
unflatten=lambda _, __: None, # type: ignore
30503066
)
3067+

tests/nnx/graph_utils_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,45 @@ def __init__(self):
10791079
self.assertIn('ls', m._pytree__nodes)
10801080
self.assertLen(jax.tree.leaves(m), 1)
10811081

1082+
def test_update_dict(self):
1083+
node = {
1084+
'a': {
1085+
'b': 1,
1086+
'c': nnx.Param(2),
1087+
'd': 3,
1088+
},
1089+
}
1090+
1091+
updates = {
1092+
'a': {
1093+
'b': 4,
1094+
'c': 10,
1095+
},
1096+
}
1097+
1098+
nnx.update(node, updates)
1099+
1100+
self.assertEqual(node['a']['b'], 4)
1101+
self.assertEqual(node['a']['c'].value, 10)
1102+
self.assertEqual(node['a']['d'], 3)
1103+
1104+
def test_pop_dict(self):
1105+
node = {
1106+
'a': {
1107+
'b': jnp.array(1),
1108+
'c': nnx.Param(2),
1109+
'd': jnp.array(3.0),
1110+
},
1111+
}
1112+
lt_2 = lambda _, x: x < 2
1113+
popped = nnx.pop(node, (nnx.Param, lt_2))
1114+
1115+
self.assertEqual(popped['a']['b'], 1)
1116+
self.assertEqual(popped['a']['c'].value, 2)
1117+
self.assertEqual(node['a']['d'], 3.0)
1118+
self.assertLen(jax.tree.leaves(node), 1)
1119+
self.assertLen(jax.tree.leaves(popped), 2)
1120+
10821121
class SimpleModule(nnx.Module):
10831122
pass
10841123

0 commit comments

Comments
 (0)