We broke the 1:1 correspondence with attribute names and variable dict names #2100

Open
@marcvanzee

from @levskaya:

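import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
  def setup(self):
    self.foo = nn.Dense(1, name="bar")  # attribute `foo`, variables stored under "bar"
    self.qup = self.param('baz', lambda k: jnp.zeros((1,)))  # attribute `qup`, variable "baz"
  def __call__(self, x):
    return self.foo(x) + self.qup

Foo().init(jax.random.PRNGKey(0), jnp.zeros((1,)))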
FrozenDict({
    params: {
        baz: DeviceArray([0.], dtype=float32),
        bar: {
            kernel: DeviceArray([[-0.58376074]], dtype=float32),
            bias: DeviceArray([0.], dtype=float32),
        },
    },
})

The above used to break loudly, and it should!
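
To make the mismatch concrete, the attribute names assigned in setup can be compared with the keys that ended up under params. This is only an illustrative sketch in plain Python: the two sets are copied by hand from the example above, and the variable names used here (`attribute_names`, `variable_names`, and so on) are hypothetical, not part of Flax.

```python
# Attribute names assigned in Foo.setup (copied by hand from the example).
attribute_names = {"foo", "qup"}
# Top-level keys of the resulting `params` collection (copied from the output).
variable_names = {"baz", "bar"}

# Under a 1:1 correspondence these two sets would be identical,
# so both differences below would be empty.
attrs_without_variable = sorted(attribute_names - variable_names)
variables_without_attr = sorted(variable_names - attribute_names)

print("attributes with no matching variable:", attrs_without_variable)
print("variables with no matching attribute:", variables_without_attr)
```

Here neither attribute has a variable of the same name, which is the broken correspondence this issue is about.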

Initial investigation by @avital:

Over the past 1.5 years (I ran a test against every single commit), we actually never had any commit where the following code raised an exception:

  def test_setattr_name_var_agreement_in_setup(self):
    class Foo(nn.Module):
      def setup(self):
        self.qup = self.param('baz', lambda k: 0)
      def __call__(self):
        pass

    Foo(parent=None).init(jax.random.PRNGKey(0))

But we did, in the past, entirely disallow the use of name= for submodules defined within setup, which would have prevented setting the wrong name for a submodule in setup. We lost that guard with https://github.com/google/flax/pull/976/files

I don't think we ever had tests for the variable/attribute correspondence. We do have tests that you can't define two variables with the same name in different collections, but none that check that the name aligns with the attribute being assigned to.

Suggestion from @jheek:

I think you could do something like this to disallow giving different names in setup:

def __setattr__(self, name, value):
  # Sketch only: `variables` stands for this module's variable collections.
  for col in variables:
    if name in variables[col]:
      assert variables[col][name] is value, f"A variable named {name} already exists. We don't allow variables and fields to have overlapping names"
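
As a rough illustration of the kind of guard being discussed, here is a toy, framework-free sketch of a `__setattr__` check that rejects assigning a declared variable to an attribute with a different name. It is an assumption-laden stand-in, not Flax's implementation: `ToyModule`, its `param` method, and the `_pending` bookkeeping are all hypothetical, and tracking values by `id()` is a simplification that only works because each initializer returns a fresh object that stays alive in `_variables`.

```python
class ToyModule:
    """Toy stand-in for a module; NOT Flax's nn.Module."""

    def __init__(self):
        # Use object.__setattr__ so our own guard is skipped for bookkeeping.
        object.__setattr__(self, "_variables", {"params": {}})
        # Maps id(value) -> declared variable name, until it is assigned.
        object.__setattr__(self, "_pending", {})

    def param(self, name, init_fn):
        # Hypothetical analogue of declaring a parameter under `name`.
        value = init_fn()
        self._variables["params"][name] = value
        self._pending[id(value)] = name
        return value

    def __setattr__(self, attr_name, value):
        # If this value was just declared as a variable, its declared
        # name must match the attribute it is being assigned to.
        declared = self._pending.pop(id(value), None)
        if declared is not None and declared != attr_name:
            raise ValueError(
                f"Variable declared as {declared!r} was assigned to "
                f"attribute {attr_name!r}; the names must agree."
            )
        object.__setattr__(self, attr_name, value)


m = ToyModule()
m.baz = m.param("baz", lambda: [0.0])  # names agree: accepted

try:
    m.qup = m.param("other", lambda: [0.0])  # names differ: rejected
    error_message = None
except ValueError as err:
    error_message = str(err)

print(error_message)
```

A real fix would also have to cover submodules given an explicit name= in setup, which this sketch does not attempt.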

Labels

Priority: P1 - soon (Response within 5 business days. Resolution within 30 days. Assignee required.)
