Hi, first off, thanks for a great library; Flax is awesome.
I was revisiting the documentation to get a better understanding of Flax. The Basics guide has a section on module parameters.
I wanted to point out that the code in that section does not seem to work at the moment.
Here is a stripped-down version of what is currently in the docs:
import flax.linen as nn
import jax.numpy as jnp
import jax.random as random

class SimpleDense(nn.Module):
    features: int
    kernel_init = nn.initializers.lecun_normal()

    @nn.compact
    def __call__(self, inputs):
        kernel = self.param('kernel',
                            self.kernel_init,  # Initialization function
                            (inputs.shape[-1], self.features))  # init_args
        y = jnp.dot(inputs, kernel)
        return y

x = jnp.ones((1, 7))
model = SimpleDense(features=3)
key, init_key = random.split(random.key(123))
params = model.init(init_key, x)
# Error: TypeError: Cannot interpret '7' as a data type
It seems to have something to do with how *init_args is being unpacked. I tried to reproduce similar behaviour with the following:
initializer = nn.initializers.glorot_normal()

def foo(rng_key, args):
    def initialize():
        return nn.initializers.glorot_normal()(rng_key, *args)
    return initialize()

foo(random.key(1), (4, 5))
# TypeError: Cannot interpret '5' as a data type
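The two errors may come from different mechanisms. JAX initializers take `(key, shape, dtype)`, and a `param`-style call forwards `*init_args` after the RNG key, so a shape tuple passed as one init arg should arrive intact. The pure-Python sketch below contrasts three cases. `fake_initializer` and `fake_param` are hypothetical stand-ins for a JAX initializer and for Flax's `self.param`, not the real implementations. It suggests that `foo` fails because it unpacks the tuple itself, whereas the docs example is more consistent with the module instance being prepended to the call:

```python
# Hypothetical stand-in mimicking the (key, shape, dtype) initializer signature.
def fake_initializer(key, shape, dtype="float32"):
    # Return what each parameter received so it can be inspected.
    return {"key": key, "shape": shape, "dtype": dtype}


# Hypothetical stand-in for a param call that forwards *init_args after the key.
def fake_param(init_fn, *init_args):
    rng = "rng"  # placeholder for a PRNG key
    return init_fn(rng, *init_args)


# 1. Shape tuple passed as a single init arg: it arrives intact as `shape`.
ok = fake_param(fake_initializer, (7, 3))
print(ok)  # {'key': 'rng', 'shape': (7, 3), 'dtype': 'float32'}

# 2. Unpacking the tuple yourself, as foo does, shifts 5 into `dtype`.
unpacked = fake_initializer("rng", *(4, 5))
print(unpacked)  # {'key': 'rng', 'shape': 4, 'dtype': 5}


# 3. A function stored as an unannotated class attribute is bound as a
#    method, so the instance is prepended and every argument shifts right.
class Holder:
    kernel_init = fake_initializer


bound = fake_param(Holder().kernel_init, (7, 3))
print(bound["shape"], bound["dtype"])  # rng (7, 3)
```

In case 3 the tuple `(7, 3)` ends up as `dtype`, which matches the docs error mentioning `7` rather than `3`.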
However, I had trouble navigating the Flax codebase since I am not familiar with it. Thanks again!