
Linen: Consider raising an error when reading variables from submodule before initialization? #513

Open
@avital

Within a module that uses shape inference (as most of the built-in Linen modules do), this code is fine:

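import flax.linen as nn

class MyModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    conv = nn.Conv(features=3, kernel_size=(3, 3))
    y = conv(x)  # conv sees x, infers its input shape, and creates its params
    params = conv.variables['params']  # populated: 'kernel' and 'bias'
    return y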

But if you instead do:

class MyModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    conv = nn.Conv(features=3, kernel_size=(3, 3))
    params = conv.variables.get('params', {})  # read before conv has seen any input
    y = conv(x)
    return y

Then I believe you get an empty params dict, because the parameters of conv are only initialized once its input shape is known.
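A toy sketch of why the order matters, written in plain Python rather than Flax. `LazyDense` is a hypothetical stand-in for a shape-inferring layer, and it uses a list-of-lists kernel instead of real arrays. Its kernel shape depends on the input size, so no parameters exist until the first call.

```python
# Toy model of shape-inferred ("lazy") parameters.
# LazyDense is hypothetical and is NOT a Flax class.
class LazyDense:
    def __init__(self, features):
        self.features = features
        self.params = {}  # nothing can be allocated yet: the input size is unknown

    def __call__(self, x):
        if not self.params:
            # The input size is only known now, so the kernel is created here.
            in_features = len(x)
            self.params["kernel"] = [[0.0] * self.features for _ in range(in_features)]
        kernel = self.params["kernel"]
        # y[j] = sum_i x[i] * kernel[i][j]
        return [sum(x[i] * kernel[i][j] for i in range(len(x))) for j in range(self.features)]


layer = LazyDense(features=3)
before = dict(layer.params)  # read before the first call
y = layer([1.0, 2.0])        # first call infers in_features = 2
after_shape = (len(layer.params["kernel"]), len(layer.params["kernel"][0]))
print(before, after_shape)   # {} (2, 3)
```

This only mirrors the mechanism described above; it does not reproduce Flax's scopes or variable collections.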

Users may be surprised by the silently empty dict, so instead we could raise an error when a submodule's variables are read while still empty, with a message that explains what is happening and points users in the right direction.
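The proposed check could look roughly like the following sketch. `UninitializedVariablesError` and `checked_variables` are hypothetical names, not part of Flax, and the submodule name `Conv_0` is only an example.

```python
from typing import Any, Mapping


class UninitializedVariablesError(RuntimeError):
    """Hypothetical error for reading a submodule's variables too early."""


def checked_variables(variables: Mapping[str, Any], module_name: str) -> Mapping[str, Any]:
    """Return `variables` unchanged, or raise if nothing has been created yet."""
    if not variables:
        raise UninitializedVariablesError(
            f"Variables of submodule '{module_name}' are empty. Modules that use "
            "shape inference only create their parameters on the first call, "
            f"once the input shape is known; call '{module_name}' before reading "
            "its variables."
        )
    return variables


try:
    checked_variables({}, "Conv_0")
except UninitializedVariablesError as err:
    print(err)
```

Treating any empty mapping as "not yet initialized" is a simplification: a module that legitimately has no variables (for example, one that only applies an activation) would also trigger it, so a real check would need to know whether the submodule has been called yet.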

Metadata

Assignees

No one assigned

    Labels

    Priority: P2 - no schedule (Best effort response and resolution. We have no plan to work on this at the moment.)
    Status: pull requests welcome (We agree with the direction proposed, feel free to give it a shot and file a pull request.)
