Added initializers doc #2776
base: main
Branch updated from 8564591 to bbbbd36.
Codecov Report
@@           Coverage Diff           @@
##             main    #2776   +/-   ##
=======================================
  Coverage   81.24%   81.24%
=======================================
  Files          53       53
  Lines        5663     5663
=======================================
  Hits         4601     4601
  Misses       1062     1062
Branch updated from bbbbd36 to 6c112a0.
Thanks for adding this guide, really nice! I suppose the next step is to improve docstrings? For instance, linking to this guide from docstrings seems useful.
docs/guides/initializers.md (outdated)
`Initializers` are functions that you can pass to a Module layer's optional `kernel_init` (kernel initializer) and `bias_init` (bias initializer) arguments to specify how its parameters are initialized. The full list of Flax initializers is [here](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module-flax.linen.initializers); they are in fact the same as the [JAX initializers](https://jax.readthedocs.io/en/latest/jax.nn.initializers.html).
The default kernel initializer is [`flax.linen.initializers.lecun_normal`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.initializers.lecun_normal.html) and the default bias initializer is [`flax.linen.initializers.zeros`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.initializers.zeros.html).
It would be useful if you could explain why we are using these initializers. There was some confusion about this before (see #215). Maybe open a thread on flax-core to ask?
docs/guides/initializers.md (outdated)
+++ {"id": "3Kglqd9vuxTG"}
To maintain consistency, every `Initializer` function passed to the `kernel_init` and `bias_init` arguments **must follow the signature `[PRNGKey, Shape, Dtype] -> Array`**. Most functions in the [Flax initializer list](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module-flax.linen.initializers) are **builder functions**: calling them returns an `Initializer` function with this signature. The two exceptions are [`flax.linen.initializers.zeros`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.initializers.zeros.html) and [`flax.linen.initializers.ones`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.initializers.ones.html), which are already `Initializer` functions with this signature. This is why, in the example above, we must call `lecun_normal()` to build an `Initializer` function, whereas we can pass `zeros` directly.
Thinking of "prefer simple APIs over more documentation", I am wondering whether it wouldn't be easier to just add two fake builder functions for `zeros` and `ones`, as follows:
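from typing import Any, Callable
import jax
import jax.numpy as jnp

DTypeLikeInexact = Any  # stand-in for JAX's internal dtype alias

def zeros(dtype: DTypeLikeInexact = jnp.float_) -> Callable[..., jax.Array]:
  def init(key: jax.Array, shape: tuple, dtype: DTypeLikeInexact = dtype) -> jax.Array:
    del key  # zeros are deterministic, so the PRNG key is unused
    return jnp.zeros(shape, dtype)
  return init  # the builder returns the initializer, not an array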
Then we don't need this complicated explanation and people can simply use all initializers consistently. WDYT?
I think this is a great idea! It would make things much more consistent. The question is what we should do with the original `flax.linen.initializers.zeros` and `flax.linen.initializers.ones`. If we replace them with the fake builder functions you suggested, wouldn't this break code that uses these initializers? On the other hand, if we keep them, I think users may get confused about the difference, i.e. which ones they should use, and which ones to call versus which ones to pass directly.
+++ {"id": "S4X_xHHk-b4V"}

## `Initializer` restrictions for `bias_init`
Nice! Maybe we should link to this from our `bias_init` docstrings as well?
Do you mean all docstrings that contain a description for `bias_init` (like here)? Something like: "bias_init: initializer function for the bias. To see restrictions on valid initializers, refer to our guide: https://flax.readthedocs.io/en/latest/guides/initializers.html#initializer-restrictions-for-bias-init"
Branch updated from 6c112a0 to d03db7f.
Branch updated from d03db7f to eed5cd9.
Resolves #2749 and #1386.
Created initializers documentation. View the doc here.