2828
2929A = TypeVar ('A' )
3030
31-
3231class PytreeTest (absltest .TestCase ):
3332 def test_pytree (self ):
3433 class Foo (nnx .Pytree ):
@@ -1001,13 +1000,10 @@ def __call__(self, x, *, rngs: nnx.Rngs):
10011000
10021001 assert isinstance (y , jax .Array )
10031002
1004- def test_modules_iterator (self ):
1005- class Foo (nnx .Module ):
1006- def __init__ (self , * , rngs : nnx .Rngs ):
1007- self .submodules = nnx .data ([
1003+ self .submodules = [
10081004 {'a' : nnx .Linear (1 , 1 , rngs = rngs )},
10091005 {'b' : nnx .Conv (1 , 1 , 1 , rngs = rngs )},
1010- ])
1006+ ]
10111007 self .linear = nnx .Linear (1 , 1 , rngs = rngs )
10121008 self .dropout = nnx .Dropout (0.5 , rngs = rngs )
10131009
@@ -1030,10 +1026,10 @@ def __init__(self, *, rngs: nnx.Rngs):
10301026 def test_children_modules_iterator (self ):
10311027 class Foo (nnx .Module ):
10321028 def __init__ (self , * , rngs : nnx .Rngs ):
1033- self .submodules = nnx . data ( [
1029+ self .submodules = [
10341030 {'a' : nnx .Linear (1 , 1 , rngs = rngs )},
10351031 {'b' : nnx .Conv (1 , 1 , 1 , rngs = rngs )},
1036- ])
1032+ ]
10371033 self .linear = nnx .Linear (1 , 1 , rngs = rngs )
10381034 self .dropout = nnx .Dropout (0.5 , rngs = rngs )
10391035
0 commit comments