
Common Initializers Do Not Work with Bias #2749

@daskol


I expected the bias vector to be initializable with initializers other than zeros, but the following code fails:

```python
import jax, jax.numpy as jnp
import flax.linen as nn

layer = nn.Conv(features=32,
                kernel_size=(3, 3),
                use_bias=True,
                bias_init=nn.initializers.lecun_normal())

jax.jit(layer.init)(jax.random.PRNGKey(42), jnp.empty((5, 28, 28, 1)))

# File jax/core.py:1969, in NamedShape.__getitem__(self, idx)
#    1967 try:
#    1968   idx = operator.index(idx)
# -> 1969   return self.__positional[idx]
#    1970 except TypeError:
#    1971   pass
#
# IndexError: tuple index out of range
```

Issue #1386 seems related, and possibly #2002 as well.

Priority: P1 - soon (response within 5 business days, resolution within 30 days; assignee required)