Skip to content

Commit d9dee2d

Browse files
Cristian GarciaFlax Authors
authored andcommitted
enable mlp test
PiperOrigin-RevId: 866211786
1 parent 2e1f5ad commit d9dee2d

File tree

2 files changed

+134
-29
lines changed

2 files changed

+134
-29
lines changed

flax/nnx/pytreelib.py

Lines changed: 63 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,6 @@ def _pytree_meta_construct(cls, self, *args, **kwargs):
406406

407407
def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P:
408408
node = cls.__new__(cls, *args, **kwargs)
409-
vars_obj = vars(node)
410409
object.__setattr__(node, '_pytree__state', PytreeState())
411410
object.__setattr__(node, '_pytree__nodes', cls._pytree__nodes)
412411
cls._pytree_meta_construct(node, *args, **kwargs)
@@ -498,6 +497,11 @@ def __init_subclass__(
498497
**kwargs,
499498
) -> None:
500499
super().__init_subclass__(**kwargs)
500+
if slots := getattr(cls, '__slots__', ()):
501+
raise TypeError(
502+
'Pytree currently does not support __slots__, '
503+
f"found __slots__={slots} in '{cls.__name__}'."
504+
)
501505
cls._pytree__is_pytree = pytree
502506

503507
graph.register_graph_node_type(
@@ -874,22 +878,30 @@ def _pytree__flatten_with_paths(self):
874878
else:
875879
key_fn = None
876880
node_attributes = self._pytree__nodes
877-
node_names: list[str] = []
881+
node_keys: list[str | int] = []
878882
node_attrs: list[tuple[tp.Any, tp.Any]] = []
879-
static_attrs: list[tuple[str, tp.Any]] = []
880-
for name, value in sorted(obj_items, key=key_fn):
881-
if name in node_attributes and node_attributes[name]:
882-
node_names.append(name)
883+
static_keys: list[str | int] = []
884+
static_attrs: list[tp.Any] = []
885+
for key, value in sorted(obj_items, key=key_fn):
886+
# get string representation of the key because
887+
# node_attributes keys are strings
888+
key_str = _get_str(key)
889+
if key_str in node_attributes and node_attributes[key_str]:
890+
node_keys.append(key)
883891
node_attrs.append((
884-
jax.tree_util.GetAttrKey(name)
885-
if isinstance(name, str)
886-
else jax.tree_util.SequenceKey(name),
892+
jax.tree_util.GetAttrKey(key)
893+
if isinstance(key, str)
894+
else jax.tree_util.SequenceKey(key),
887895
value,
888896
))
889897
else:
890-
static_attrs.append((name, value))
898+
static_keys.append(key)
899+
static_attrs.append(value)
891900

892-
return node_attrs, (tuple(node_names), tuple(static_attrs))
901+
return (
902+
node_attrs,
903+
(tuple(node_keys), tuple(static_keys), tuple(static_attrs)),
904+
)
893905

894906
def _pytree__flatten(self):
895907
obj_items = vars(self).items()
@@ -899,34 +911,43 @@ def _pytree__flatten(self):
899911
else:
900912
key_fn = None
901913
node_attributes = self._pytree__nodes
902-
node_names: list[str] = []
914+
node_keys: list[str | int] = []
903915
node_attrs: list[tp.Any] = []
904-
static_attrs: list[tuple[str, tp.Any]] = []
905-
for name, value in sorted(obj_items, key=key_fn):
906-
if name in node_attributes and node_attributes[name]:
907-
node_names.append(name)
916+
static_keys: list[str | int] = []
917+
static_attrs: list[tp.Any] = []
918+
for key, value in sorted(obj_items, key=key_fn):
919+
# get string representation of the key because
920+
# node_attributes keys are strings
921+
key_str = _get_str(key)
922+
if key_str in node_attributes and node_attributes[key_str]:
923+
node_keys.append(key)
908924
node_attrs.append(value)
909925
else:
910-
static_attrs.append((name, value))
926+
static_keys.append(key)
927+
static_attrs.append(value)
911928

912-
return node_attrs, (tuple(node_names), tuple(static_attrs))
929+
return (
930+
node_attrs,
931+
(tuple(node_keys), tuple(static_keys), tuple(static_attrs)),
932+
)
913933

914934
@classmethod
915935
def _pytree__unflatten(
916936
cls,
917-
static: tuple[tuple[str, ...], tuple[tuple[str, tp.Any], ...]],
937+
static: tuple[tuple[str | int, ...], tuple[str | int, ...], tuple[tp.Any, ...]],
918938
node_attrs: tp.Iterable[tp.Any],
919939
):
920-
node_names, static_attrs = static
940+
node_keys, static_keys, static_attrs = static
921941
obj = object.__new__(cls)
922-
vars_obj = vars(obj)
923942
if cls._pytree__has_int_keys:
924-
node_names = tuple(
925-
str(name) if isinstance(name, int) else name for name in node_names
926-
)
927-
for name, value in zip(node_names, node_attrs, strict=True):
943+
node_keys_iter = map(_get_str, node_keys)
944+
static_keys_iter = map(_get_str, static_keys)
945+
else:
946+
node_keys_iter = node_keys
947+
static_keys_iter = static_keys
948+
for name, value in zip(node_keys_iter, node_attrs, strict=True):
928949
object.__setattr__(obj, name, value)
929-
for name, value in static_attrs:
950+
for name, value in zip(static_keys_iter, static_attrs, strict=True):
930951
object.__setattr__(obj, name, value)
931952
return obj
932953

@@ -946,7 +967,16 @@ def _graph_node_flatten(self):
946967
def _graph_node_set_key(self, key, value: tp.Any):
947968
if self._pytree__has_int_keys and isinstance(key, int):
948969
key = str(key)
949-
setattr(self, key, value)
970+
if not isinstance(key, str):
971+
raise KeyError(f'Invalid key: {key!r}')
972+
elif (
973+
hasattr(self, key)
974+
and isinstance(variable := getattr(self, key), Variable)
975+
and isinstance(value, Variable)
976+
):
977+
variable.update_from_state(value)
978+
else:
979+
setattr(self, key, value)
950980

951981
def _graph_node_pop_key(self, key):
952982
if self._pytree__has_int_keys and isinstance(key, int):
@@ -978,7 +1008,8 @@ def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]):
9781008
(str(name) if isinstance(name, int) else name, value)
9791009
for name, value in attributes
9801010
)
981-
vars(self).update(attributes)
1011+
for name, value in attributes:
1012+
object.__setattr__(self, name, value)
9821013

9831014
if tp.TYPE_CHECKING:
9841015
def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: ...
@@ -1000,4 +1031,7 @@ def _maybe_int(x):
10001031
try:
10011032
return int(x)
10021033
except (ValueError, TypeError):
1003-
return x
1034+
return x
1035+
1036+
def _get_str(x):
1037+
return x if isinstance(x, str) else str(x)

tests/nnx/helpers_test.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,77 @@ def test_list_mutable_sequence(self):
159159

160160
self.assertEqual(l[1:3], [6, 7])
161161

162+
def test_list_fori_loop(self):
163+
class Foo(nnx.Module):
164+
def __init__(self):
165+
self.layers = nnx.List([
166+
nnx.Linear(1, 1, rngs=nnx.Rngs(0)),
167+
nnx.Linear(1, 1, rngs=nnx.Rngs(0)),
168+
])
169+
170+
def batch_loop_body(i, carry):
171+
return carry
172+
173+
net = Foo()
174+
jax.lax.fori_loop(0, 2, batch_loop_body, net)
175+
176+
def test_list_pytree_default_behavior(self):
177+
ls = nnx.List([jnp.array(1), jnp.array(2), jnp.array(3)])
178+
leaves = jax.tree_util.tree_leaves(ls)
179+
self.assertLen(leaves, 3)
180+
np.testing.assert_array_equal(leaves[0], jnp.array(1))
181+
np.testing.assert_array_equal(leaves[1], jnp.array(2))
182+
np.testing.assert_array_equal(leaves[2], jnp.array(3))
183+
184+
def test_list_pytree_static_elements(self):
185+
ls = nnx.List([nnx.static(10), nnx.static(20), nnx.static(30)])
186+
leaves = jax.tree_util.tree_leaves(ls)
187+
self.assertEmpty(leaves)
188+
189+
def test_list_pytree_data_elements(self):
190+
ls = nnx.List([nnx.data(1), nnx.data(2), nnx.data(3)])
191+
leaves = jax.tree_util.tree_leaves(ls)
192+
self.assertLen(leaves, 3)
193+
self.assertEqual(leaves[0], 1)
194+
self.assertEqual(leaves[1], 2)
195+
self.assertEqual(leaves[2], 3)
196+
197+
def test_list_pytree_mixed_static_data(self):
198+
ls = nnx.List([
199+
nnx.data(jnp.array(1)),
200+
nnx.static(100),
201+
nnx.data(jnp.array(2)),
202+
nnx.static(200),
203+
])
204+
leaves = jax.tree_util.tree_leaves(ls)
205+
self.assertLen(leaves, 2)
206+
np.testing.assert_array_equal(leaves[0], jnp.array(1))
207+
np.testing.assert_array_equal(leaves[1], jnp.array(2))
208+
209+
def test_list_pytree_flatten_unflatten(self):
210+
ls = nnx.List([nnx.data(10), nnx.static('hello'), nnx.data(20)])
211+
leaves, treedef = jax.tree_util.tree_flatten(ls)
212+
self.assertLen(leaves, 2)
213+
self.assertEqual(leaves[0], 10)
214+
self.assertEqual(leaves[1], 20)
215+
216+
new_leaves = [x * 2 for x in leaves]
217+
new_ls = jax.tree_util.tree_unflatten(treedef, new_leaves)
218+
self.assertEqual(new_ls[0], 20)
219+
self.assertEqual(new_ls[1], 'hello')
220+
self.assertEqual(new_ls[2], 40)
221+
222+
def test_list_pytree_jit(self):
223+
ls = nnx.List([nnx.data(jnp.array(1.0)), nnx.static(999)])
224+
225+
@jax.jit
226+
def double(ls):
227+
return jax.tree.map(lambda x: x * 2, ls)
228+
229+
result = double(ls)
230+
np.testing.assert_array_equal(result[0], jnp.array(2.0))
231+
self.assertEqual(result[1], 999)
232+
162233

163234
if __name__ == '__main__':
164235
absltest.main()

0 commit comments

Comments
 (0)