Description
from @levskaya:
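import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
  def setup(self):
    # Attribute `foo` is given the submodule name "bar".
    self.foo = nn.Dense(1, name="bar")
    # Attribute `qup` is given the parameter name 'baz'.
    self.qup = self.param('baz', lambda k: jnp.zeros((1,)))
  def __call__(self, x):
    return self.foo(x) + self.qup

Foo().init(jax.random.PRNGKey(0), jnp.zeros((1,)))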
FrozenDict({
    params: {
        baz: DeviceArray([0.], dtype=float32),
        bar: {
            kernel: DeviceArray([[-0.58376074]], dtype=float32),
            bias: DeviceArray([0.], dtype=float32),
        },
    },
})
The above used to break loudly, and it should!
Initial investigation by @avital:
Over the past 1.5 years (I ran a test against every single commit), we actually never had any commit where the following code raised an exception:
def test_setattr_name_var_agreement_in_setup(self):
  class Foo(nn.Module):
    def setup(self):
      self.qup = self.param('baz', lambda k: 0)
    def __call__(self):
      pass
  Foo(parent=None).init(jax.random.PRNGKey(0))
But in the past we did entirely disallow the use of name= for submodules defined within setup, which would have prevented setting the wrong name for a submodule in setup. We lost that guard with https://github.com/google/flax/pull/976/files
I don't think we ever had tests for the correspondence between a variable's name and the attribute it is assigned to. We do have tests that you can't define two variables with the same name in different collections, but none that the name matches the attribute being assigned to.
Suggestion from @jheek:
I think you could do something like this to disallow giving different names in setup:
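def __setattr__(self, name, value):
  # Sketch: `variables` stands for the module's variable collections.
  for col in variables:
    if name in variables[col]:
      assert variables[col][name] is value, f"A variable named {name} already exists. We don't allow variables and fields to have overlapping names"
  super().__setattr__(name, value)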
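
To make the idea concrete outside of Flax, here is a minimal pure-Python sketch of a `__setattr__` guard along these lines. Everything in it (`ToyModule`, `ToyVariable`, `NameMismatchError`) is hypothetical and invented for illustration; it is not Flax's `nn.Module` and does not reflect how Flax stores variables. It assumes `param` returns a wrapper object that remembers its own name, which lets the guard compare that name with the attribute being assigned.

```python
from dataclasses import dataclass
from typing import Any, Callable


class NameMismatchError(Exception):
    """Raised when a variable name and the attribute name disagree."""


@dataclass(eq=False)
class ToyVariable:
    # Hypothetical wrapper: remembers where and under what name it was registered.
    collection: str
    name: str
    value: Any


class ToyModule:
    """Hypothetical stand-in for a module; NOT flax.linen.Module."""

    def __init__(self):
        # Bypass the guarded __setattr__ while creating the bookkeeping dict.
        object.__setattr__(self, "_variables", {})  # collection -> {name: ToyVariable}

    def param(self, name: str, init_fn: Callable[[], Any]) -> ToyVariable:
        var = ToyVariable("params", name, init_fn())
        self._variables.setdefault("params", {})[name] = var
        return var

    def __setattr__(self, attr: str, value: Any) -> None:
        # Guard 1: a variable must be bound to an attribute with the same name.
        if isinstance(value, ToyVariable) and value.name != attr:
            raise NameMismatchError(
                f"Variable {value.name!r} assigned to attribute {attr!r}.")
        # Guard 2: an attribute may not shadow a variable of the same name.
        for entries in self._variables.values():
            if attr in entries and entries[attr] is not value:
                raise NameMismatchError(
                    f"A variable named {attr!r} already exists.")
        object.__setattr__(self, attr, value)


m = ToyModule()
m.baz = m.param("baz", lambda: 0.0)      # names agree: accepted
try:
    m.qup = m.param("other", lambda: 0.0)  # names disagree: rejected
except NameMismatchError as err:
    print("rejected:", err)
```

The second guard mirrors the overlap check in the suggestion above; the first is an extra assumption of this sketch and only works because the toy variable carries its own name.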