
Improve RTD for initializers #1386

Open
@billmark

Description



Problem you have encountered:

I was trying to exactly mimic the default bias initialization of PyTorch's Linear layer.

The first problem I encountered is that the bias_init arg to nn.Dense is not documented well enough. In particular, it's not clear what arguments the initializer callable needs to accept. I eventually read the source code, but that shouldn't be necessary.
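For reference, Flax appears to follow the same convention as jax.nn.initializers: a callable taking a PRNG key, a shape, and a dtype, returning an array of that shape. A minimal sketch (the name my_bias_init and the uniform range are illustrative, not part of the API):

```python
import jax
import jax.numpy as jnp

# A minimal sketch of the callable signature that bias_init (and
# kernel_init) expects: (key, shape, dtype) -> array of that shape.
# The name and the uniform range here are illustrative only.
def my_bias_init(key, shape, dtype=jnp.float32):
    return jax.random.uniform(key, shape, dtype, minval=-0.01, maxval=0.01)
```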

The second problem I encountered is that the initializers in jax.nn.initializers, e.g. lecun_normal(), can't easily (or maybe at all) be used with bias_init. These initializers expect a 2D shape (as one would have for kernel_init), but the bias weights are only 1D.
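To illustrate (the exact error may vary by JAX version), the variance-scaling initializers read fan_in from shape[-2], which a 1D bias shape doesn't have:

```python
import jax

key = jax.random.PRNGKey(0)
init = jax.nn.initializers.lecun_normal()

kernel = init(key, (3, 4))  # works: 2D kernel shape
# bias = init(key, (4,))    # fails: fan_in is computed from shape[-2],
#                           # which a 1D bias shape does not have
```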

Suggested action:

  1. Document the kernel_init and bias_init args better.
  2. Figure out how to exactly emulate the PyTorch Linear layer's bias init (see the sketch after this list). Note that you'll need to look at the PyTorch source code, because the PyTorch docs aren't precise enough.
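For item 2: per the PyTorch source, nn.Linear initializes the weight with kaiming_uniform_(a=sqrt(5)) and the bias from uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)). A sketch of how that could be reproduced in Flax; the helper name torch_linear_bias_init is hypothetical, and fan_in has to be passed in explicitly because the 1D bias shape alone doesn't carry it:

```python
import math
import jax
import jax.numpy as jnp
import flax.linen as nn

def torch_linear_bias_init(fan_in):
    # Matches PyTorch's default Linear bias init:
    # uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)). The helper name is
    # hypothetical; fan_in must be supplied by the caller.
    bound = 1.0 / math.sqrt(fan_in)
    def init(key, shape, dtype=jnp.float32):
        return jax.random.uniform(key, shape, dtype, minval=-bound, maxval=bound)
    return init

# kaiming_uniform_(a=sqrt(5)) has gain sqrt(2 / (1 + 5)) = sqrt(1/3), so it
# corresponds to variance_scaling with scale = 1/3, mode "fan_in", uniform.
layer = nn.Dense(
    features=16,
    kernel_init=jax.nn.initializers.variance_scaling(1.0 / 3.0, "fan_in", "uniform"),
    bias_init=torch_linear_bias_init(fan_in=32),  # must match the input width
)
```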

Metadata

Labels

Priority: P1 - soon (Response within 5 business days. Resolution within 30 days. Assignee required.)