*Module Parameters* section of docs is outdated. #3761

Open
@PaulScemama

Hi, first off thanks for a great library -- flax is awesome.

I was revisiting the documentation to get a better understanding of flax. The basics guide has a section on module parameters.

I wanted to point out that the code in that section does not appear to work at the moment.

Here is a stripped-down version of what is currently in the docs:

import flax.linen as nn
import jax.numpy as jnp
import jax.random as random


class SimpleDense(nn.Module):
  features: int
  kernel_init = nn.initializers.lecun_normal()

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # init_args
    y = jnp.dot(inputs, kernel)
    return y

x = jnp.ones((1, 7))
model = SimpleDense(features=3)
key, init_key = random.split(random.key(123))

params = model.init(init_key, x)
# Error: TypeError: Cannot interpret '7' as a data type
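
One possible explanation, offered as an assumption and not checked against the flax source: the stripped version above declares `kernel_init` without a type annotation. Flax modules are dataclasses, and in plain Python an un-annotated function stored on a class is not a dataclass field, so accessing it through `self` yields a bound method that receives `self` as its first argument and shifts every other argument one position to the right. The sketch below reproduces only that Python behaviour; `fake_init`, `Unannotated`, and `Annotated` are hypothetical stand-ins, not flax APIs.

```python
from dataclasses import dataclass
from typing import Callable


def fake_init(key, shape, dtype="float32"):
    """Hypothetical stand-in for an initializer called as (key, shape, dtype)."""
    return {"key": key, "shape": shape, "dtype": dtype}


@dataclass
class Unannotated:
    features: int
    kernel_init = fake_init  # no annotation: a plain class attribute, not a field


@dataclass
class Annotated:
    features: int
    kernel_init: Callable = fake_init  # annotated: a dataclass field stored per instance


# Bound method: `self` takes the `key` slot, so the shape tuple lands in `dtype`.
shifted = Unannotated(features=3).kernel_init("rng", (7, 3))
print(shifted["shape"], shifted["dtype"])

# Plain function: the arguments arrive in the intended positions.
aligned = Annotated(features=3).kernel_init("rng", (7, 3))
print(aligned["shape"], aligned["dtype"])
```

If this is what happens in `SimpleDense`, the shape tuple would end up where a dtype is expected, which could plausibly produce the message about `'7'`. Under that assumption, annotating the attribute (for example `kernel_init: Callable = nn.initializers.lecun_normal()`) would be the thing to try first.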

It seems to have something to do with how *init_args is being unpacked. I tried reproducing similar behaviour with the following:

initializer = nn.initializers.glorot_normal()

def foo(rng_key, args):
    
    def initialize():
        return nn.initializers.glorot_normal()(rng_key, *args)

    return initialize()

foo(random.key(1), (4,5))
# TypeError: Cannot interpret '5' as a data type
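
A note on this second snippet, as a sketch under stated assumptions: JAX initializers are called as `init(key, shape, dtype)`, so unpacking `args = (4, 5)` with `*args` passes `4` as the shape and `5` as the dtype, which would account for the message about `'5'`. The stand-in `fake_init` below is hypothetical and only mirrors that argument order; it shows how wrapping the shape in an outer tuple keeps it in one piece.

```python
def fake_init(key, shape, dtype="float32"):
    """Hypothetical stand-in mirroring the (key, shape, dtype) argument order."""
    return {"shape": shape, "dtype": dtype}


def foo(rng_key, args):
    # Same unpacking as in the snippet above.
    return fake_init(rng_key, *args)


split = foo("rng", (4, 5))     # the tuple is spread over shape and dtype
whole = foo("rng", ((4, 5),))  # the shape stays a single positional argument
print(split)
print(whole)
```

Read this way, the second error may come from how the reproduction itself is called, and not necessarily from the same path as the `SimpleDense` failure.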

I had trouble navigating the flax codebase, though, as I am unfamiliar with it. Thanks again!

Labels: Priority: P1 - soon (Response within 5 business days. Resolution within 30 days. Assignee required.)
