Conversation
I feel we could simplify the intro by doing the following:

model = nnx.Sequential(
    nnx.Linear(2, 8, rngs=rngs),
    nnx.relu,
    nnx.Linear(8, 8, rngs=rngs),
)
optimizer = nnx.Optimizer(
    model,
    tx=optax.adam(1e-3),
    wrt=nnx.Param,
)
...

@nnx.jit
def train_step(model, optimizer, ema, x, y):
    loss_fn = lambda m, x, y: jnp.sum((m(x) - y) ** 2)
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    optimizer.update(model, grads)
    ema.update(model)
    return loss
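The train_step above passes an ema object that isn't defined in the snippet. A minimal sketch of what it could look like; the EMA class, its decay default, and the use of nnx.clone/nnx.update are assumptions for illustration, not part of the suggestion:

import jax
from flax import nnx

class EMA(nnx.Module):
    # Hypothetical helper: keeps a shadow copy of the model whose Params
    # track an exponential moving average of the live model's Params.
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.model = nnx.clone(model)  # shadow copy holding the averages

    def update(self, model):
        avg = nnx.state(self.model, nnx.Param)
        new = nnx.state(model, nnx.Param)
        avg = jax.tree.map(
            lambda a, p: self.decay * a + (1.0 - self.decay) * p, avg, new)
        nnx.update(self.model, avg)  # write the new averages in place

ema = EMA(model)  # then passed to train_step as above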
docs_nnx/guides/opt_cookbook.rst (outdated)

    model = nnx_model(rngs)
    state = nnx.state(model, nnx.Param)
    rates = {'kernel': optax.adam(1e-3), 'bias': optax.adam(1e-2)}
    param_tys = nnx.map_state(lambda p, v: list(p)[-1], state)
I think this is enough:
-    param_tys = nnx.map_state(lambda p, v: list(p)[-1], state)
+    param_tys = nnx.map_state(lambda p, v: p[-1], state)
We could also use jax.tree.map_with_path as in the JAX example.
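For context, a rough sketch of how those per-parameter labels could drive the different learning rates from the rates dict via optax.multi_transform; the exact wiring below, including passing the label State straight to optax and into nnx.Optimizer, is an assumption rather than the guide's actual code:

import optax
from flax import nnx

# Label each Param by the last key of its path ('kernel' or 'bias').
state = nnx.state(model, nnx.Param)
param_tys = nnx.map_state(lambda p, v: p[-1], state)

# One optax transform per label, mirroring the rates dict in the diff above.
tx = optax.multi_transform(
    {'kernel': optax.adam(1e-3), 'bias': optax.adam(1e-2)},
    param_labels=param_tys)
optimizer = nnx.Optimizer(model, tx=tx, wrt=nnx.Param)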
docs_nnx/guides/opt_cookbook.rst (outdated)

    axis_types=(AxisType.Explicit, AxisType.Explicit))
    jax.set_mesh(mesh)

    ghost_model = jax.eval_shape(lambda: nnx_model(nnx.Rngs(0), out_sharding=P('x', 'y')))
Instead of creating this fake model, it would be a good opportunity to create the optimizer_sharding API on Variable before finishing this guide.
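For reference, the approach currently in the diff expands to roughly the sketch below: the model is built abstractly with jax.eval_shape so shardings can be read off without materializing parameters. nnx_model and the explicit-axis mesh come from the guide; that each abstract leaf carries a reusable .sharding is an assumption here.

import jax
from jax.sharding import PartitionSpec as P
from flax import nnx

# Build the model abstractly: no parameters are allocated, only shapes.
ghost_model = jax.eval_shape(
    lambda: nnx_model(nnx.Rngs(0), out_sharding=P('x', 'y')))
ghost_state = nnx.state(ghost_model, nnx.Param)
# Under the explicit-axis mesh set above, the abstract leaves should carry
# shardings that can be reused when laying out the optimizer state.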
After fully reading the guide, I'm getting the sense that having the JAX versions makes the explanations a bit longer and slightly harder to understand (because you have to mentally filter for the version you are interested in), and having the JAX version doesn't necessarily make the NNX version easier to understand.
Fair enough! I'll convert it to nnx-only.
What does this PR do?
This PR adds a guide that shows some common techniques for working with Flax models during optimization. These include:
This document emphasizes a style as close to pure JAX as possible: to that end, it shows how the Flax version of each technique requires only minor deviation from the often more intuitive pure-JAX version.
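As a rough illustration of that claim (this example is not taken from the guide; the toy loss and optimizer are assumptions), compare a pure-JAX training step with its Flax NNX counterpart:

import jax
import jax.numpy as jnp
import optax
from flax import nnx

tx = optax.sgd(1e-2)

# Pure JAX: parameters and optimizer state are explicit pytrees that are
# threaded through the step and returned updated.
@jax.jit
def jax_step(params, opt_state, x, y):
    loss_fn = lambda p, x, y: jnp.sum((x @ p['w'] + p['b'] - y) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss

# Flax NNX: the same logic, but the model and optimizer are updated in place.
@nnx.jit
def nnx_step(model, optimizer, x, y):
    loss_fn = lambda m, x, y: jnp.sum((m(x) - y) ** 2)
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss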
Warnings: