Description
The bias vector should accept initializers other than zeros, but the following code fails.
import jax, jax.numpy as jnp
import flax.linen as nn
layer = nn.Conv(features=32,
                kernel_size=(3, 3),
                use_bias=True,
                bias_init=nn.initializers.lecun_normal())
jax.jit(layer.init)(jax.random.PRNGKey(42), jnp.empty((5, 28, 28, 1)))
# File jax/core.py:1969, in NamedShape.__getitem__(self, idx)
# 1967 try:
# 1968 idx = operator.index(idx)
# -> 1969 return self.__positional[idx]
# 1970 except TypeError:
# 1971 pass
#
# IndexError: tuple index out of range
Issue #1386 seems to be related, and #2002 may be related as well.