31
31
from flax import struct
32
32
from flax .core import Scope , freeze , FrozenDict , tracers
33
33
from flax .linen import compact
34
+ from flax .configurations import use_regular_dict
34
35
import jax
35
36
from jax import random
36
37
from jax .nn import initializers
41
42
# Parse absl flags test_srcdir and test_tmpdir.
42
43
jax .config .parse_flags_with_absl ()
43
44
44
-
45
45
def tree_equals (x , y ):
46
46
return jax .tree_util .tree_all (jax .tree_util .tree_map (operator .eq , x , y ))
47
47
@@ -1140,6 +1140,7 @@ def test(self):
1140
1140
A ().test ()
1141
1141
self .assertFalse (setup_called )
1142
1142
1143
+ @use_regular_dict ()
1143
1144
def test_module_pass_as_attr (self ):
1144
1145
1145
1146
class A (nn .Module ):
@@ -1158,7 +1159,7 @@ def __call__(self, x):
1158
1159
1159
1160
variables = A ().init (random .PRNGKey (0 ), jnp .ones ((1 ,)))
1160
1161
var_shapes = jax .tree_util .tree_map (jnp .shape , variables )
1161
- ref_var_shapes = freeze ( {
1162
+ ref_var_shapes = {
1162
1163
'params' : {
1163
1164
'b' : {
1164
1165
'foo' : {
@@ -1167,9 +1168,10 @@ def __call__(self, x):
1167
1168
}
1168
1169
},
1169
1170
},
1170
- })
1171
+ }
1171
1172
self .assertTrue (tree_equals (var_shapes , ref_var_shapes ))
1172
1173
1174
+ @use_regular_dict ()
1173
1175
def test_module_pass_in_closure (self ):
1174
1176
a = nn .Dense (2 )
1175
1177
@@ -1183,17 +1185,18 @@ def __call__(self, x):
1183
1185
1184
1186
variables = B ().init (random .PRNGKey (0 ), jnp .ones ((1 ,)))
1185
1187
var_shapes = jax .tree_util .tree_map (jnp .shape , variables )
1186
- ref_var_shapes = freeze ( {
1188
+ ref_var_shapes = {
1187
1189
'params' : {
1188
1190
'foo' : {
1189
1191
'bias' : (2 ,),
1190
1192
'kernel' : (1 , 2 ),
1191
1193
}
1192
1194
},
1193
- })
1195
+ }
1194
1196
self .assertTrue (tree_equals (var_shapes , ref_var_shapes ))
1195
1197
self .assertIsNone (a .name )
1196
1198
1199
+ @use_regular_dict ()
1197
1200
def test_toplevel_submodule_adoption (self ):
1198
1201
1199
1202
class Encoder (nn .Module ):
@@ -1233,7 +1236,7 @@ def __call__(self, x):
1233
1236
self .assertEqual (y .shape , (4 , 5 ))
1234
1237
1235
1238
var_shapes = jax .tree_util .tree_map (jnp .shape , variables )
1236
- ref_var_shapes = freeze ( {
1239
+ ref_var_shapes = {
1237
1240
'params' : {
1238
1241
'dense_out' : {
1239
1242
'bias' : (5 ,),
@@ -1246,9 +1249,10 @@ def __call__(self, x):
1246
1249
},
1247
1250
},
1248
1251
},
1249
- })
1252
+ }
1250
1253
self .assertTrue (tree_equals (var_shapes , ref_var_shapes ))
1251
1254
1255
+ @use_regular_dict ()
1252
1256
def test_toplevel_submodule_adoption_pytree (self ):
1253
1257
1254
1258
class A (nn .Module ):
@@ -1276,7 +1280,7 @@ def __call__(self, c, x):
1276
1280
1277
1281
params = B (a_pytree ).init (key , x , x )
1278
1282
unused_y , counters = b .apply (params , x , x , mutable = 'counter' )
1279
- ref_counters = freeze ( {
1283
+ ref_counters = {
1280
1284
'counter' : {
1281
1285
'A_bar' : {
1282
1286
'i' : jnp .array (2.0 ),
@@ -1285,13 +1289,14 @@ def __call__(self, c, x):
1285
1289
'i' : jnp .array (2.0 ),
1286
1290
},
1287
1291
},
1288
- })
1292
+ }
1289
1293
self .assertTrue (
1290
1294
jax .tree_util .tree_all (
1291
1295
jax .tree_util .tree_map (
1292
1296
lambda x , y : np .testing .assert_allclose (x , y , atol = 1e-7 ),
1293
1297
counters , ref_counters )))
1294
1298
1299
+ @use_regular_dict ()
1295
1300
def test_toplevel_submodule_adoption_sharing (self ):
1296
1301
dense = functools .partial (nn .Dense , use_bias = False )
1297
1302
@@ -1323,7 +1328,7 @@ def __call__(self, x):
1323
1328
c = C (a , b )
1324
1329
p = c .init (key , x )
1325
1330
var_shapes = jax .tree_util .tree_map (jnp .shape , p )
1326
- ref_var_shapes = freeze ( {
1331
+ ref_var_shapes = {
1327
1332
'params' : {
1328
1333
'Dense_0' : {
1329
1334
'kernel' : (2 , 2 ),
@@ -1339,9 +1344,10 @@ def __call__(self, x):
1339
1344
},
1340
1345
},
1341
1346
},
1342
- })
1347
+ }
1343
1348
self .assertTrue (tree_equals (var_shapes , ref_var_shapes ))
1344
1349
1350
+ @use_regular_dict ()
1345
1351
def test_toplevel_named_submodule_adoption (self ):
1346
1352
dense = functools .partial (nn .Dense , use_bias = False )
1347
1353
@@ -1369,7 +1375,7 @@ def __call__(self, x):
1369
1375
init_vars = b .init (k , x )
1370
1376
var_shapes = jax .tree_util .tree_map (jnp .shape , init_vars )
1371
1377
if config .flax_preserve_adopted_names :
1372
- ref_var_shapes = freeze ( {
1378
+ ref_var_shapes = {
1373
1379
'params' : {
1374
1380
'foo' : {
1375
1381
'dense' : {
@@ -1380,9 +1386,9 @@ def __call__(self, x):
1380
1386
'kernel' : (4 , 6 ),
1381
1387
},
1382
1388
},
1383
- })
1389
+ }
1384
1390
else :
1385
- ref_var_shapes = freeze ( {
1391
+ ref_var_shapes = {
1386
1392
'params' : {
1387
1393
'a' : {
1388
1394
'dense' : {
@@ -1393,9 +1399,10 @@ def __call__(self, x):
1393
1399
'kernel' : (4 , 6 ),
1394
1400
},
1395
1401
},
1396
- })
1402
+ }
1397
1403
self .assertTrue (tree_equals (var_shapes , ref_var_shapes ))
1398
1404
1405
+ @use_regular_dict ()
1399
1406
def test_toplevel_submodule_pytree_adoption_sharing (self ):
1400
1407
1401
1408
class A (nn .Module ):
@@ -1423,13 +1430,13 @@ def __call__(self, x):
1423
1430
1424
1431
params = b .init (key , x )
1425
1432
_ , counters = b .apply (params , x , mutable = 'counter' )
1426
- ref_counters = freeze ( {
1433
+ ref_counters = {
1427
1434
'counter' : {
1428
1435
'A_bar' : {
1429
1436
'i' : jnp .array (6.0 ),
1430
1437
},
1431
1438
},
1432
- })
1439
+ }
1433
1440
self .assertTrue (tree_equals (counters , ref_counters ))
1434
1441
1435
1442
def test_inner_class_def (self ):
@@ -1650,7 +1657,6 @@ def __call__(self, x):
1650
1657
1651
1658
x = jnp .ones ((3 ,))
1652
1659
variables = Foo ().init (random .PRNGKey (0 ), x )
1653
- variables = variables .unfreeze ()
1654
1660
y = Foo ().apply (variables , x )
1655
1661
self .assertEqual (y .shape , (2 ,))
1656
1662
0 commit comments