Skip to content

Commit 2ab6208

Browse files
Cristian GarciaFlax Authors
authored andcommitted
Improve Pytree flatten/unflatten
* Correctly converts int static keys to strs during unflatten. * No longer uses setattr during unflatten. PiperOrigin-RevId: 866538022
1 parent 2e1f5ad commit 2ab6208

File tree

3 files changed

+134
-38
lines changed

3 files changed

+134
-38
lines changed

flax/nnx/pytreelib.py

Lines changed: 61 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,6 @@ def _pytree_meta_construct(cls, self, *args, **kwargs):
406406

407407
def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P:
408408
node = cls.__new__(cls, *args, **kwargs)
409-
vars_obj = vars(node)
410409
object.__setattr__(node, '_pytree__state', PytreeState())
411410
object.__setattr__(node, '_pytree__nodes', cls._pytree__nodes)
412411
cls._pytree_meta_construct(node, *args, **kwargs)
@@ -498,6 +497,11 @@ def __init_subclass__(
498497
**kwargs,
499498
) -> None:
500499
super().__init_subclass__(**kwargs)
500+
if slots := getattr(cls, '__slots__', ()):
501+
raise TypeError(
502+
'Pytree currently does not support __slots__, '
503+
f"found __slots__={slots} in '{cls.__name__}'."
504+
)
501505
cls._pytree__is_pytree = pytree
502506

503507
graph.register_graph_node_type(
@@ -874,22 +878,30 @@ def _pytree__flatten_with_paths(self):
874878
else:
875879
key_fn = None
876880
node_attributes = self._pytree__nodes
877-
node_names: list[str] = []
881+
node_keys: list[str | int] = []
878882
node_attrs: list[tuple[tp.Any, tp.Any]] = []
879-
static_attrs: list[tuple[str, tp.Any]] = []
880-
for name, value in sorted(obj_items, key=key_fn):
881-
if name in node_attributes and node_attributes[name]:
882-
node_names.append(name)
883+
static_keys: list[str | int] = []
884+
static_attrs: list[tp.Any] = []
885+
for key, value in sorted(obj_items, key=key_fn):
886+
# get string representation of the key because
887+
# node_attributes keys are strings
888+
key_str = _get_str(key)
889+
if key_str in node_attributes and node_attributes[key_str]:
890+
node_keys.append(key)
883891
node_attrs.append((
884-
jax.tree_util.GetAttrKey(name)
885-
if isinstance(name, str)
886-
else jax.tree_util.SequenceKey(name),
892+
jax.tree_util.GetAttrKey(key)
893+
if isinstance(key, str)
894+
else jax.tree_util.SequenceKey(key),
887895
value,
888896
))
889897
else:
890-
static_attrs.append((name, value))
898+
static_keys.append(key)
899+
static_attrs.append(value)
891900

892-
return node_attrs, (tuple(node_names), tuple(static_attrs))
901+
return (
902+
node_attrs,
903+
(tuple(node_keys), tuple(static_keys), tuple(static_attrs)),
904+
)
893905

894906
def _pytree__flatten(self):
895907
obj_items = vars(self).items()
@@ -899,35 +911,38 @@ def _pytree__flatten(self):
899911
else:
900912
key_fn = None
901913
node_attributes = self._pytree__nodes
902-
node_names: list[str] = []
914+
node_keys: list[str | int] = []
903915
node_attrs: list[tp.Any] = []
904-
static_attrs: list[tuple[str, tp.Any]] = []
905-
for name, value in sorted(obj_items, key=key_fn):
906-
if name in node_attributes and node_attributes[name]:
907-
node_names.append(name)
916+
static_keys: list[str | int] = []
917+
static_attrs: list[tp.Any] = []
918+
for key, value in sorted(obj_items, key=key_fn):
919+
# get string representation of the key because
920+
# node_attributes keys are strings
921+
key_str = _get_str(key)
922+
if key_str in node_attributes and node_attributes[key_str]:
923+
node_keys.append(key)
908924
node_attrs.append(value)
909925
else:
910-
static_attrs.append((name, value))
926+
static_keys.append(key)
927+
static_attrs.append(value)
911928

912-
return node_attrs, (tuple(node_names), tuple(static_attrs))
929+
return (
930+
node_attrs,
931+
(tuple(node_keys), tuple(static_keys), tuple(static_attrs)),
932+
)
913933

914934
@classmethod
915935
def _pytree__unflatten(
916936
cls,
917-
static: tuple[tuple[str, ...], tuple[tuple[str, tp.Any], ...]],
937+
static: tuple[tp.Iterable[str | int], tp.Iterable[str | int], tp.Iterable[tp.Any]],
918938
node_attrs: tp.Iterable[tp.Any],
919939
):
920-
node_names, static_attrs = static
940+
node_keys, static_keys, static_attrs = static
921941
obj = object.__new__(cls)
922-
vars_obj = vars(obj)
923-
if cls._pytree__has_int_keys:
924-
node_names = tuple(
925-
str(name) if isinstance(name, int) else name for name in node_names
926-
)
927-
for name, value in zip(node_names, node_attrs, strict=True):
928-
object.__setattr__(obj, name, value)
929-
for name, value in static_attrs:
930-
object.__setattr__(obj, name, value)
942+
for name, value in zip(node_keys, node_attrs, strict=True):
943+
object.__setattr__(obj, _get_str(name), value)
944+
for name, value in zip(static_keys, static_attrs, strict=True):
945+
object.__setattr__(obj, _get_str(name), value)
931946
return obj
932947

933948
# -------------------------
@@ -946,7 +961,16 @@ def _graph_node_flatten(self):
946961
def _graph_node_set_key(self, key, value: tp.Any):
947962
if self._pytree__has_int_keys and isinstance(key, int):
948963
key = str(key)
949-
setattr(self, key, value)
964+
if not isinstance(key, str):
965+
raise KeyError(f'Invalid key: {key!r}')
966+
elif (
967+
hasattr(self, key)
968+
and isinstance(variable := getattr(self, key), Variable)
969+
and isinstance(value, Variable)
970+
):
971+
variable.update_from_state(value)
972+
else:
973+
setattr(self, key, value)
950974

951975
def _graph_node_pop_key(self, key):
952976
if self._pytree__has_int_keys and isinstance(key, int):
@@ -972,13 +996,9 @@ def _graph_node_create_empty(node_type: tp.Type[P]) -> P:
972996
def _graph_node_clear(self):
973997
vars(self).clear()
974998

975-
def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]):
976-
if self._pytree__has_int_keys:
977-
attributes = (
978-
(str(name) if isinstance(name, int) else name, value)
979-
for name, value in attributes
980-
)
981-
vars(self).update(attributes)
999+
def _graph_node_init(self, attributes: tp.Iterable[tuple[str | int, tp.Any]]):
1000+
for name, value in attributes:
1001+
object.__setattr__(self, _get_str(name), value)
9821002

9831003
if tp.TYPE_CHECKING:
9841004
def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: ...
@@ -1000,4 +1020,7 @@ def _maybe_int(x):
10001020
try:
10011021
return int(x)
10021022
except (ValueError, TypeError):
1003-
return x
1023+
return x
1024+
1025+
def _get_str(x):
1026+
return x if isinstance(x, str) else str(x)

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ filterwarnings = [
202202
"ignore:.*[Data|Static]' is deprecated, please replace.*:DeprecationWarning",
203203
# DeprecationWarning: Implicit conversion of an array to a dtype is deprecated; rather than dtype=arr use dtype=arr.dtype.
204204
"ignore:.*Implicit conversion of an array to a dtype is deprecated; rather than dtype=arr use dtype=arr.dtype.*:DeprecationWarning",
205+
# DeprecationWarning: Setting `jax_pmap_shmap_merge` is deprecated in JAX v0.9.0 and will be removed in JAX v0.10.0
206+
"ignore:.*Setting `jax_pmap_shmap_merge` is deprecated in JAX.*:DeprecationWarning",
205207
]
206208

207209
[tool.coverage.report]

tests/nnx/helpers_test.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,77 @@ def test_list_mutable_sequence(self):
159159

160160
self.assertEqual(l[1:3], [6, 7])
161161

162+
def test_list_fori_loop(self):
163+
class Foo(nnx.Module):
164+
def __init__(self):
165+
self.layers = nnx.List([
166+
nnx.Linear(1, 1, rngs=nnx.Rngs(0)),
167+
nnx.Linear(1, 1, rngs=nnx.Rngs(0)),
168+
])
169+
170+
def batch_loop_body(i, carry):
171+
return carry
172+
173+
net = Foo()
174+
jax.lax.fori_loop(0, 2, batch_loop_body, net)
175+
176+
def test_list_pytree_default_behavior(self):
177+
ls = nnx.List([jnp.array(1), jnp.array(2), jnp.array(3)])
178+
leaves = jax.tree_util.tree_leaves(ls)
179+
self.assertLen(leaves, 3)
180+
np.testing.assert_array_equal(leaves[0], jnp.array(1))
181+
np.testing.assert_array_equal(leaves[1], jnp.array(2))
182+
np.testing.assert_array_equal(leaves[2], jnp.array(3))
183+
184+
def test_list_pytree_static_elements(self):
185+
ls = nnx.List([nnx.static(10), nnx.static(20), nnx.static(30)])
186+
leaves = jax.tree_util.tree_leaves(ls)
187+
self.assertEmpty(leaves)
188+
189+
def test_list_pytree_data_elements(self):
190+
ls = nnx.List([nnx.data(1), nnx.data(2), nnx.data(3)])
191+
leaves = jax.tree_util.tree_leaves(ls)
192+
self.assertLen(leaves, 3)
193+
self.assertEqual(leaves[0], 1)
194+
self.assertEqual(leaves[1], 2)
195+
self.assertEqual(leaves[2], 3)
196+
197+
def test_list_pytree_mixed_static_data(self):
198+
ls = nnx.List([
199+
nnx.data(jnp.array(1)),
200+
nnx.static(100),
201+
nnx.data(jnp.array(2)),
202+
nnx.static(200),
203+
])
204+
leaves = jax.tree_util.tree_leaves(ls)
205+
self.assertLen(leaves, 2)
206+
np.testing.assert_array_equal(leaves[0], jnp.array(1))
207+
np.testing.assert_array_equal(leaves[1], jnp.array(2))
208+
209+
def test_list_pytree_flatten_unflatten(self):
210+
ls = nnx.List([nnx.data(10), nnx.static('hello'), nnx.data(20)])
211+
leaves, treedef = jax.tree_util.tree_flatten(ls)
212+
self.assertLen(leaves, 2)
213+
self.assertEqual(leaves[0], 10)
214+
self.assertEqual(leaves[1], 20)
215+
216+
new_leaves = [x * 2 for x in leaves]
217+
new_ls = jax.tree_util.tree_unflatten(treedef, new_leaves)
218+
self.assertEqual(new_ls[0], 20)
219+
self.assertEqual(new_ls[1], 'hello')
220+
self.assertEqual(new_ls[2], 40)
221+
222+
def test_list_pytree_jit(self):
223+
ls = nnx.List([nnx.data(jnp.array(1.0)), nnx.static(999)])
224+
225+
@jax.jit
226+
def double(ls):
227+
return jax.tree.map(lambda x: x * 2, ls)
228+
229+
result = double(ls)
230+
np.testing.assert_array_equal(result[0], jnp.array(2.0))
231+
self.assertEqual(result[1], 999)
232+
162233

163234
if __name__ == '__main__':
164235
absltest.main()

0 commit comments

Comments
 (0)