Description
Problem you have encountered:
I was trying to exactly mimic the default bias initialization of PyTorch's Linear layer.
The first problem I encountered is that the bias_init argument to nn.Dense is not documented well enough. In particular, it's not clear what arguments the initializer callable receives. I eventually read the source code, but that shouldn't be necessary.
The second problem I encountered is that the initializers in jax.nn.initializers, e.g. lecun_normal(), can't be used easily (or maybe at all) with bias_init. These initializers expect a shape with at least two dimensions (as one would have for kernel_init) in order to compute fan-in/fan-out, but the bias shape is only 1D.
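For reference, a minimal snippet reproducing the mismatch (assuming Flax's convention that initializers are called as `init(key, shape, dtype)`): lecun_normal() works for a 2D kernel shape but errors out for a 1D bias shape, since in the JAX versions I checked it derives fan-in/fan-out from the last two axes of the shape.

```python
import jax

key = jax.random.PRNGKey(0)
init = jax.nn.initializers.lecun_normal()

# Fine for a 2D kernel shape, e.g. (in_features, out_features):
kernel = init(key, (16, 8))

# Flax calls bias_init(key, shape, dtype) with the 1D shape (features,).
# lecun_normal() computes fan-in from shape[-2], so a 1D shape fails.
try:
    init(key, (8,))
    failed = False
except Exception:
    failed = True
```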
Suggested action:
- Document kernel_init and bias_init args better
- Figure out how to exactly emulate the PyTorch linear layer's bias initialization. Note that this requires reading the PyTorch source code, because the PyTorch docs aren't precise enough.
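As a possible workaround, here is a sketch (not a vetted implementation): per its reset_parameters source, PyTorch's Linear draws the bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in)). Since Flax's bias_init only receives (key, shape, dtype) with the 1D shape (features,), fan_in cannot be recovered from the shape and has to be captured in a closure. The helper name pytorch_bias_init is mine, not part of either library.

```python
import math
import jax
import jax.numpy as jnp

def pytorch_bias_init(fan_in):
    """Bias initializer mimicking torch.nn.Linear's default.

    PyTorch draws bias ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in)), where fan_in
    is the number of input features. Flax's bias_init callable only sees
    (key, shape, dtype) with shape == (features,), so fan_in must be
    supplied here by hand.
    """
    bound = 1.0 / math.sqrt(fan_in) if fan_in > 0 else 0.0

    def init(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, dtype,
                                  minval=-bound, maxval=bound)

    return init

# Usage sketch (fan_in must match the layer's input width):
# layer = nn.Dense(features=8, bias_init=pytorch_bias_init(fan_in=16))
```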