Skip to content

Commit aa7ab77

Browse files
committed
Make lists return true for is_data
1 parent f075798 commit aa7ab77

File tree

7 files changed: +14 −32 lines changed

examples/gemma/transformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ def __init__(
537537
dtype=config.dtype,
538538
rngs=rngs,
539539
)
540-
self.layers = nnx.data([
540+
self.layers = [
541541
modules.Block(
542542
config=config,
543543
attn_type=attn_type,
@@ -547,7 +547,7 @@ def __init__(
547547
for _, attn_type in zip(
548548
range(config.num_layers), config.attention_types
549549
)
550-
])
550+
]
551551
self.final_norm = layers.RMSNorm(
552552
config.embed_dim,
553553
scale_init=modules.maybe_with_partitioning(

examples/nnx_toy_examples/hijax_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def create_block(rngs, /):
141141
# self.blocks = nnx.stateful(create_block(rngs.fork(split=num_blocks)))
142142
self.blocks = create_block(rngs.fork(split=num_blocks))
143143
else:
144-
self.blocks = nnx.data([Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)])
144+
self.blocks = [Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)]
145145

146146
def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None):
147147
self.count[...] += 1

flax/nnx/helpers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from flax.nnx.module import Module
2626
from flax.nnx.proxy_caller import ApplyCaller
2727
from flax.nnx.rnglib import Rngs
28-
from flax.nnx.pytreelib import data
2928
from flax.nnx.statelib import State
3029
from flax.training.train_state import struct
3130

@@ -120,7 +119,7 @@ def __init__(self, *fns: tp.Callable[..., tp.Any]):
120119
Args:
121120
*fns: A sequence of callables to apply.
122121
"""
123-
self.layers = data(list(fns))
122+
self.layers = list(fns)
124123

125124
def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any:
126125
if len(self.layers) == 0:

flax/nnx/pytreelib.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def is_data(value: tp.Any, /) -> bool:
173173
... # ------ STATIC ------------
174174
>>> assert not nnx.is_data( 'hello' ) # strings, arbitrary objects
175175
>>> assert not nnx.is_data( 42 ) # int, float, bool, complex, etc.
176-
>>> assert not nnx.is_data( [1, 2.0, 3j, jnp.array(1)] ) # list, dict, tuple, pytrees
176+
>>> assert nnx.is_data( [1, 2.0, 3j, jnp.array(1)] ) # list, dict, tuple, pytrees
177177
178178
179179
Args:
@@ -183,7 +183,8 @@ def is_data(value: tp.Any, /) -> bool:
183183
A string representing the attribute status.
184184
"""
185185
return (
186-
graph.is_node_leaf(value)
186+
isinstance(value, list)
187+
or graph.is_node_leaf(value)
187188
or graph.is_graph_node(value)
188189
or type(value) in DATA_REGISTRY
189190
)

tests/nnx/graph_utils_test.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
import jax.numpy as jnp
2525

2626

27-
28-
2927
class StatefulLinear(nnx.Module):
3028
def __init__(self, din, dout, rngs):
3129
self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
@@ -1007,17 +1005,6 @@ def f(m):
10071005

10081006
self.assertEqual(m.a, 2)
10091007

1010-
def test_data_after_init(self):
1011-
test = self
1012-
class Foo(nnx.Module):
1013-
def __init__(self):
1014-
self.ls = []
1015-
self.ls.append(jnp.array(1))
1016-
1017-
with self.assertRaisesRegex(
1018-
ValueError, 'Found unexpected Arrays on value of type'
1019-
):
1020-
m = Foo()
10211008

10221009
def test_update_dict(self):
10231010
node = {

tests/nnx/module_test.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
A = TypeVar('A')
3030

31-
3231
class PytreeTest(absltest.TestCase):
3332
def test_pytree(self):
3433
class Foo(nnx.Pytree):
@@ -1001,13 +1000,10 @@ def __call__(self, x, *, rngs: nnx.Rngs):
10011000

10021001
assert isinstance(y, jax.Array)
10031002

1004-
def test_modules_iterator(self):
1005-
class Foo(nnx.Module):
1006-
def __init__(self, *, rngs: nnx.Rngs):
1007-
self.submodules = nnx.data([
1003+
self.submodules = [
10081004
{'a': nnx.Linear(1, 1, rngs=rngs)},
10091005
{'b': nnx.Conv(1, 1, 1, rngs=rngs)},
1010-
])
1006+
]
10111007
self.linear = nnx.Linear(1, 1, rngs=rngs)
10121008
self.dropout = nnx.Dropout(0.5, rngs=rngs)
10131009

@@ -1030,10 +1026,10 @@ def __init__(self, *, rngs: nnx.Rngs):
10301026
def test_children_modules_iterator(self):
10311027
class Foo(nnx.Module):
10321028
def __init__(self, *, rngs: nnx.Rngs):
1033-
self.submodules = nnx.data([
1029+
self.submodules = [
10341030
{'a': nnx.Linear(1, 1, rngs=rngs)},
10351031
{'b': nnx.Conv(1, 1, 1, rngs=rngs)},
1036-
])
1032+
]
10371033
self.linear = nnx.Linear(1, 1, rngs=rngs)
10381034
self.dropout = nnx.Dropout(0.5, rngs=rngs)
10391035

tests/nnx/nn/attention_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,9 @@ class Model(nnx.Module):
5151
attention_kwargs: dict
5252

5353
def __init__(self, attention_kwargs, rng):
54-
self.attention_layers = nnx.data([
55-
nnx.MultiHeadAttention(**attention_kwargs, rngs=rng) for i in range(3)
56-
])
57-
54+
self.attention_layers = [
55+
nnx.MultiHeadAttention(**attention_kwargs, rngs=rng) for i in range(3)
56+
]
5857
def __call__(self, x, sow_weights=False):
5958
x = self.attention_layers[0](x, sow_weights=sow_weights)
6059
x = self.attention_layers[1](x)

0 commit comments

Comments (0)