Skip to content

Commit 1519161

Browse files
author
Flax Authors
committed
Merge pull request #2898 from chiamp:frozen_dict_tests
PiperOrigin-RevId: 512190368
2 parents 776582c + fea9116 commit 1519161

9 files changed

+94
-49
lines changed

flax/configurations.py

+20
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import os
2525
from jax import config as jax_config
2626

27+
from contextlib import ContextDecorator
28+
2729
# Keep a wrapper at the flax namespace, in case we make our implementation
2830
# in the future.
2931
config = jax_config
@@ -66,6 +68,23 @@ def static_bool_env(varname: str, default: bool) -> bool:
6668
'invalid truth value {!r} for environment {!r}'.format(val, varname))
6769

6870

71+
class use_regular_dict(ContextDecorator):
72+
"""Context decorator for test functions to temporarily use regular dicts
73+
instead of FrozenDicts.
74+
75+
This is a temporary feature flag to help migrate to FrozenDicts. Returning
76+
FrozenDicts will be deprecated and removed in the future.
77+
"""
78+
def __enter__(self):
79+
self._old_value = config.flax_return_frozendict # save current env value
80+
config.update('flax_return_frozendict', False) # return regular dicts
81+
return self
82+
83+
def __exit__(self, *exc):
84+
config.update('flax_return_frozendict', self._old_value) # switch back to old env value
85+
return False
86+
87+
6988
# Flax Global Configuration Variables:
7089

7190
# Whether to use the lazy rng implementation.
@@ -96,6 +115,7 @@ def static_bool_env(varname: str, default: bool) -> bool:
96115
default=False,
97116
help=("When adopting outside modules, don't clobber existing names."))
98117

118+
#TODO(marcuschiam): remove this feature flag once regular dict migration is complete
99119
flax_return_frozendict = define_bool_state(
100120
name='return_frozendict',
101121
default=True,

flax/linen/summary.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -310,15 +310,15 @@ def _get_path_variables(path: Tuple[str, ...], variables: FrozenVariableDict) ->
310310
path_variables = {}
311311

312312
for collection in variables:
313-
collection_variables = variables[collection]
313+
collection_variables = jax.tree_util.tree_map(lambda x: x, variables[collection]) # make a deep copy
314314
for name in path:
315315
if name not in collection_variables:
316316
collection_variables = None
317317
break
318318
collection_variables = collection_variables[name]
319319

320320
if collection_variables is not None:
321-
path_variables[collection] = collection_variables.unfreeze()
321+
path_variables[collection] = collection_variables
322322

323323
return path_variables
324324

tests/core/core_lift_test.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
import operator
1616
from flax import errors
17-
from flax.core import Scope, init, apply, lift, nn, FrozenDict, unfreeze
17+
from flax.core import Scope, init, apply, lift, nn, FrozenDict, unfreeze, copy
18+
from flax.configurations import use_regular_dict
1819

1920
import jax
2021
from jax import random
@@ -25,6 +26,7 @@
2526

2627
from absl.testing import absltest
2728

29+
2830
class LiftTest(absltest.TestCase):
2931

3032
def test_aliasing(self):
@@ -155,11 +157,12 @@ def false_fn(scope, x):
155157

156158
x = jnp.ones((1, 3))
157159
y1, vars = init(f)(random.PRNGKey(0), x, True)
158-
self.assertEqual(vars['state'].unfreeze(), {'true_count': 1, 'false_count': 0})
160+
self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 0})
159161
y2, vars = apply(f, mutable="state")(vars, x, False)
160-
self.assertEqual(vars['state'].unfreeze(), {'true_count': 1, 'false_count': 1})
162+
self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 1})
161163
np.testing.assert_allclose(y1, -y2)
162164

165+
@use_regular_dict()
163166
def test_switch(self):
164167
def f(scope, x, index):
165168
scope.variable('state', 'a_count', lambda: 0)
@@ -182,14 +185,14 @@ def c_fn(scope, x):
182185

183186
x = jnp.ones((1, 3))
184187
y1, vars = init(f)(random.PRNGKey(0), x, 0)
185-
self.assertEqual(vars['state'].unfreeze(), {'a_count': 1, 'b_count': 0, 'c_count': 0})
188+
self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 0, 'c_count': 0})
186189
y2, updates = apply(f, mutable="state")(vars, x, 1)
187-
vars = vars.copy(updates)
188-
self.assertEqual(vars['state'].unfreeze(), {'a_count': 1, 'b_count': 1, 'c_count': 0})
190+
vars = copy(vars, updates)
191+
self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 0})
189192
np.testing.assert_allclose(y1, -y2)
190193
y3, updates = apply(f, mutable="state")(vars, x, 2)
191-
vars = vars.copy(updates)
192-
self.assertEqual(vars['state'].unfreeze(), {'a_count': 1, 'b_count': 1, 'c_count': 1})
194+
vars = copy(vars, updates)
195+
self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 1})
193196
np.testing.assert_allclose(y1, y3)
194197

195198
def test_subscope_var_aliasing(self):

tests/core/core_meta_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def body(scope, x):
110110
return c
111111

112112
_, variables = init(f)(random.PRNGKey(0), jnp.zeros((8, 3)))
113-
boxed_shapes = jax.tree_map(jnp.shape, variables['params'].unfreeze())
113+
boxed_shapes = jax.tree_map(jnp.shape, variables['params'])
114114
self.assertEqual(boxed_shapes, {
115115
'kernel': meta.Partitioned((8, 3, 3), ('layers', 'in', 'out')),
116116
'bias': (8, 3),

tests/linen/linen_attention_test.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
from flax import linen as nn
2121
from flax import jax_utils
22+
from flax.core import pop
23+
from flax.configurations import use_regular_dict
2224

2325
import jax
2426
from jax import lax
@@ -31,7 +33,6 @@
3133
# Parse absl flags test_srcdir and test_tmpdir.
3234
jax.config.parse_flags_with_absl()
3335

34-
3536
class AttentionTest(parameterized.TestCase):
3637

3738
def test_multihead_self_attention(self):
@@ -102,6 +103,7 @@ def test_causal_mask_1d(self):
102103
np.testing.assert_allclose(mask_1d, mask_1d_simple,)
103104

104105
@parameterized.parameters([((5,), (1,)), ((6, 5), (2,))])
106+
@use_regular_dict()
105107
def test_decoding(self, spatial_shape, attn_dims):
106108
bs = 2
107109
num_heads = 3
@@ -119,7 +121,7 @@ def test_decoding(self, spatial_shape, attn_dims):
119121
decode_module = module.clone(decode=True)
120122

121123
initial_vars = decode_module.init(key2, inputs)
122-
state, params = initial_vars.pop('params')
124+
state, params = pop(initial_vars, 'params')
123125
causal_mask = nn.attention.make_causal_mask(jnp.ones((bs,) + spatial_shape))
124126
y_ref = jax.jit(lambda x, y: module.apply(initial_vars, x, y))(
125127
inputs, causal_mask)

tests/linen/linen_meta_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def body(_, c):
144144
x = jnp.ones((8, 128))
145145
spec = nn.get_partition_spec(
146146
jax.eval_shape(model.init, random.PRNGKey(0), x))
147-
self.assertEqual(spec.unfreeze(), {
147+
self.assertEqual(spec, {
148148
'params': {
149149
'MLP_0': {
150150
'Dense_0': {

tests/linen/linen_module_test.py

+24-18
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from flax import struct
3232
from flax.core import Scope, freeze, FrozenDict, tracers
3333
from flax.linen import compact
34+
from flax.configurations import use_regular_dict
3435
import jax
3536
from jax import random
3637
from jax.nn import initializers
@@ -41,7 +42,6 @@
4142
# Parse absl flags test_srcdir and test_tmpdir.
4243
jax.config.parse_flags_with_absl()
4344

44-
4545
def tree_equals(x, y):
4646
return jax.tree_util.tree_all(jax.tree_util.tree_map(operator.eq, x, y))
4747

@@ -1140,6 +1140,7 @@ def test(self):
11401140
A().test()
11411141
self.assertFalse(setup_called)
11421142

1143+
@use_regular_dict()
11431144
def test_module_pass_as_attr(self):
11441145

11451146
class A(nn.Module):
@@ -1158,7 +1159,7 @@ def __call__(self, x):
11581159

11591160
variables = A().init(random.PRNGKey(0), jnp.ones((1,)))
11601161
var_shapes = jax.tree_util.tree_map(jnp.shape, variables)
1161-
ref_var_shapes = freeze({
1162+
ref_var_shapes = {
11621163
'params': {
11631164
'b': {
11641165
'foo': {
@@ -1167,9 +1168,10 @@ def __call__(self, x):
11671168
}
11681169
},
11691170
},
1170-
})
1171+
}
11711172
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
11721173

1174+
@use_regular_dict()
11731175
def test_module_pass_in_closure(self):
11741176
a = nn.Dense(2)
11751177

@@ -1183,17 +1185,18 @@ def __call__(self, x):
11831185

11841186
variables = B().init(random.PRNGKey(0), jnp.ones((1,)))
11851187
var_shapes = jax.tree_util.tree_map(jnp.shape, variables)
1186-
ref_var_shapes = freeze({
1188+
ref_var_shapes = {
11871189
'params': {
11881190
'foo': {
11891191
'bias': (2,),
11901192
'kernel': (1, 2),
11911193
}
11921194
},
1193-
})
1195+
}
11941196
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
11951197
self.assertIsNone(a.name)
11961198

1199+
@use_regular_dict()
11971200
def test_toplevel_submodule_adoption(self):
11981201

11991202
class Encoder(nn.Module):
@@ -1233,7 +1236,7 @@ def __call__(self, x):
12331236
self.assertEqual(y.shape, (4, 5))
12341237

12351238
var_shapes = jax.tree_util.tree_map(jnp.shape, variables)
1236-
ref_var_shapes = freeze({
1239+
ref_var_shapes = {
12371240
'params': {
12381241
'dense_out': {
12391242
'bias': (5,),
@@ -1246,9 +1249,10 @@ def __call__(self, x):
12461249
},
12471250
},
12481251
},
1249-
})
1252+
}
12501253
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
12511254

1255+
@use_regular_dict()
12521256
def test_toplevel_submodule_adoption_pytree(self):
12531257

12541258
class A(nn.Module):
@@ -1276,7 +1280,7 @@ def __call__(self, c, x):
12761280

12771281
params = B(a_pytree).init(key, x, x)
12781282
unused_y, counters = b.apply(params, x, x, mutable='counter')
1279-
ref_counters = freeze({
1283+
ref_counters = {
12801284
'counter': {
12811285
'A_bar': {
12821286
'i': jnp.array(2.0),
@@ -1285,13 +1289,14 @@ def __call__(self, c, x):
12851289
'i': jnp.array(2.0),
12861290
},
12871291
},
1288-
})
1292+
}
12891293
self.assertTrue(
12901294
jax.tree_util.tree_all(
12911295
jax.tree_util.tree_map(
12921296
lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7),
12931297
counters, ref_counters)))
12941298

1299+
@use_regular_dict()
12951300
def test_toplevel_submodule_adoption_sharing(self):
12961301
dense = functools.partial(nn.Dense, use_bias=False)
12971302

@@ -1323,7 +1328,7 @@ def __call__(self, x):
13231328
c = C(a, b)
13241329
p = c.init(key, x)
13251330
var_shapes = jax.tree_util.tree_map(jnp.shape, p)
1326-
ref_var_shapes = freeze({
1331+
ref_var_shapes = {
13271332
'params': {
13281333
'Dense_0': {
13291334
'kernel': (2, 2),
@@ -1339,9 +1344,10 @@ def __call__(self, x):
13391344
},
13401345
},
13411346
},
1342-
})
1347+
}
13431348
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
13441349

1350+
@use_regular_dict()
13451351
def test_toplevel_named_submodule_adoption(self):
13461352
dense = functools.partial(nn.Dense, use_bias=False)
13471353

@@ -1369,7 +1375,7 @@ def __call__(self, x):
13691375
init_vars = b.init(k, x)
13701376
var_shapes = jax.tree_util.tree_map(jnp.shape, init_vars)
13711377
if config.flax_preserve_adopted_names:
1372-
ref_var_shapes = freeze({
1378+
ref_var_shapes = {
13731379
'params': {
13741380
'foo': {
13751381
'dense': {
@@ -1380,9 +1386,9 @@ def __call__(self, x):
13801386
'kernel': (4, 6),
13811387
},
13821388
},
1383-
})
1389+
}
13841390
else:
1385-
ref_var_shapes = freeze({
1391+
ref_var_shapes = {
13861392
'params': {
13871393
'a': {
13881394
'dense': {
@@ -1393,9 +1399,10 @@ def __call__(self, x):
13931399
'kernel': (4, 6),
13941400
},
13951401
},
1396-
})
1402+
}
13971403
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
13981404

1405+
@use_regular_dict()
13991406
def test_toplevel_submodule_pytree_adoption_sharing(self):
14001407

14011408
class A(nn.Module):
@@ -1423,13 +1430,13 @@ def __call__(self, x):
14231430

14241431
params = b.init(key, x)
14251432
_, counters = b.apply(params, x, mutable='counter')
1426-
ref_counters = freeze({
1433+
ref_counters = {
14271434
'counter': {
14281435
'A_bar': {
14291436
'i': jnp.array(6.0),
14301437
},
14311438
},
1432-
})
1439+
}
14331440
self.assertTrue(tree_equals(counters, ref_counters))
14341441

14351442
def test_inner_class_def(self):
@@ -1650,7 +1657,6 @@ def __call__(self, x):
16501657

16511658
x = jnp.ones((3,))
16521659
variables = Foo().init(random.PRNGKey(0), x)
1653-
variables = variables.unfreeze()
16541660
y = Foo().apply(variables, x)
16551661
self.assertEqual(y.shape, (2,))
16561662

0 commit comments

Comments
 (0)