Skip to content

Add Optimization Cookbook#5117

Open
samanklesaria wants to merge 4 commits intogoogle:mainfrom
samanklesaria:opt_cookbook
Open

Add Optimization Cookbook#5117
samanklesaria wants to merge 4 commits intogoogle:mainfrom
samanklesaria:opt_cookbook

Conversation

@samanklesaria
Copy link
Collaborator

@samanklesaria samanklesaria commented Nov 28, 2025

What does this PR do?

This PR adds a guide that shows some common techniques for working with Flax models during optimization. These include:

  • Calculation of Exponential Moving Averages
  • Optimizing only a low rank addition to certain weights (LORA)
  • Using different learning rates for different parameters to implement the maximal update parameterization
  • Using second order optimizers like LBFGS.
  • Specifying sharding for optimization state that differs from that of parameter state
  • Gradient accumulation

This document emphasizes a style as close to pure jax as possible: to that end, it shows how the flax version of each technique only requires minor deviation from the often more intuitive pure-jax version.

Warnings:

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@samanklesaria samanklesaria force-pushed the opt_cookbook branch 3 times, most recently from c495dc1 to b929529 Compare December 1, 2025 23:53
@samanklesaria samanklesaria force-pushed the opt_cookbook branch 5 times, most recently from 34d7c20 to 444c6b6 Compare December 9, 2025 22:27
@samanklesaria samanklesaria marked this pull request as ready for review January 6, 2026 20:37
@samanklesaria samanklesaria force-pushed the opt_cookbook branch 2 times, most recently from f894a0d to b37c527 Compare January 20, 2026 19:56
@cgarciae
Copy link
Collaborator

I feel we could simplify the intro by doing the following:

  1. Define a single model at the begining, simply reuse it on all examples (its just a guide).
  2. Inline the loss function.
model = nnx.Sequential(
  nnx.Linear(2,8, rngs=rngs),
  nnx.relu,
  nnx.Linear(8,8, rngs=rngs),
)

optimizer = nnx.Optimizer(
  model,
  tx=optax.adam(1e-3),
  wrt=nnx.Param)

...

@nnx.jit
def train_step(model, optimizer, ema, x, y):
  loss_fn = lambda m, x, y: jnp.sum((m(x) - y) ** 2)
  loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
  optimizer.update(model, grads)
  ema.update(model)
  return loss

model = nnx_model(rngs)
state = nnx.state(model, nnx.Param)
rates = {'kernel': optax.adam(1e-3), 'bias': optax.adam(1e-2)}
param_tys = nnx.map_state(lambda p, v: list(p)[-1], state)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thin this is enough:

Suggested change
param_tys = nnx.map_state(lambda p, v: list(p)[-1], state)
param_tys = nnx.map_state(lambda p, v: p[-1], state)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also use jax.tree. map_with_path as in the JAX example.

axis_types=(AxisType.Explicit, AxisType.Explicit))
jax.set_mesh(mesh)

ghost_model = jax.eval_shape(lambda: nnx_model(nnx.Rngs(0), out_sharding=P('x', 'y')))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of creating this fake model it would be a good opportunity to create the optimizer_sharding API on Variable before finishing this guide.

@cgarciae
Copy link
Collaborator

After fully reading the guide I'm getting the sense that having the JAX versions makes explanations a bit longer and slightly harder to understand (cause you have to mentally filter for the version you are interested in) and having the JAX version doesn't necessarily make understanding the NNX version easier.

@samanklesaria
Copy link
Collaborator Author

After fully reading the guide I'm getting the sense that having the JAX versions makes explanations a bit longer and slightly harder to understand (cause you have to mentally filter for the version you are interested in) and having the JAX version doesn't necessarily make understanding the NNX version easier.

Fair enough! I'll convert it to nnx-only.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants