Implementing Weight Decay Masking with nnx.Optimizer and Filters #4737
cadazar asked this question in Show and tell
Hi, I have been using the NNX API to train a custom language model and needed to apply weight decay only to certain parameters (excluding biases, normalization parameters, and embeddings) while using `nnx.Optimizer`. I found that creating two `nnx.Optimizer` instances with different base optimizers (e.g. `optax.adamw` for the parameters that get weight decay and `optax.adam` for those that don't) and applying each one to a filtered subset of the parameters using `nnx.filterlib` works well, without any significant overhead.

Hope this can be of use to someone out there.