-
I was wondering if there's a preferred way of performing L1/L2 regularization on neural network weights in Flax? I couldn't find an example in the documentation, but I'm basically trying to replicate what the `kernel_regularizer` argument does in TensorFlow. Thanks!
-
Hey! You can do something like this to get global L2 regularization:

```python
def l2_loss(x, alpha):
    return alpha * (x ** 2).sum()

def loss_fn(...):
    ...
    loss = ...
    loss += sum(
        l2_loss(w, alpha=0.001)
        for w in jax.tree_util.tree_leaves(variables["params"])
    )
```
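Since a fuller example is requested later in the thread, here is a minimal end-to-end sketch of plugging `l2_loss` into a training step. The toy model, the use of `optax` for optimization, and the data shapes are illustrative assumptions, not part of the original reply:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):  # hypothetical toy model
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(32)(x))
        return nn.Dense(1)(x)

def l2_loss(x, alpha):
    return alpha * (x ** 2).sum()

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))["params"]
tx = optax.sgd(1e-2)
opt_state = tx.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    def loss_fn(params):
        preds = model.apply({"params": params}, x)
        loss = jnp.mean((preds - y) ** 2)
        # Add the global L2 penalty over every parameter leaf.
        loss += sum(
            l2_loss(w, alpha=0.001)
            for w in jax.tree_util.tree_leaves(params)
        )
        return loss

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```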
-
I prefer @cgarciae's approach, but if you need to control how/which params get regularized from inside the model, it's easier to `sow` the regularization loss; a sketch of that pattern follows.
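The original comment's code did not survive, so here is a minimal sketch of the `sow` pattern, assuming a hypothetical layer; the `"losses"` collection name and the `0.001` coefficient are illustrative choices, not from the original reply:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class RegularizedDense(nn.Module):  # hypothetical layer
    features: int

    @nn.compact
    def __call__(self, x):
        kernel = self.param(
            "kernel", nn.initializers.lecun_normal(), (x.shape[-1], self.features)
        )
        # Stash this layer's penalty in a mutable 'losses' collection.
        self.sow("losses", "l2", 0.001 * (kernel ** 2).sum())
        return x @ kernel

model = RegularizedDense(features=4)
x = jnp.ones((1, 8))
params = model.init(jax.random.PRNGKey(0), x)["params"]

# Mark the collection mutable so apply returns the sown values.
y, state = model.apply({"params": params}, x, mutable=["losses"])
reg_loss = sum(jax.tree_util.tree_leaves(state["losses"]))
```

The advantage of this pattern is that each `Module` decides for itself what to regularize, and the training loop just sums whatever was sown.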
-
As a supplement to @cgarciae's answer, you can filter the param tree so that only specific parameters (e.g. kernels but not biases) get regularized:

```python
import jax
from collections.abc import Iterable

def find_params_by_node_name(params, node_name):
    def _is_leaf_fun(x):
        # Treat a dict whose values are all arrays as a leaf node.
        return isinstance(x, Iterable) and jax.tree_util.all_leaves(x.values())

    def _get_key_finder(key):
        def _finder(x):
            value = x.get(key)
            return None if value is None else {key: value}
        return _finder

    filtered_params = jax.tree_util.tree_map(
        _get_key_finder(node_name), params, is_leaf=_is_leaf_fun
    )
    filtered_params = [
        x for x in jax.tree_util.tree_leaves(filtered_params) if x is not None
    ]
    return filtered_params

model = MyModel()
params = model.init(...)["params"]
kernels = find_params_by_node_name(params, "kernel")
```
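To tie this back to the first reply, the filtered list can be fed straight into `l2_loss` so that only kernels are penalized (the `alpha` value here is illustrative):

```python
loss += sum(l2_loss(w, alpha=0.001) for w in kernels)
```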
-
Can you give a full example of using the `l2_loss` update?
-
Specifically with convolutions: I was testing out WeightNorm and came up with a wrapper that might work. It basically wraps another conv cell and uses a new kernel param for the wrapped convolution so I can perform WS before invoking it. I like the idea of modeling this into the Modules themselves, but @cgarciae's implementation just fits the Flax paradigms better.
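The wrapper itself was not posted. As a rough sketch of the idea, here is one way to write a weight-normalized conv in Flax; rather than wrapping an existing `nn.Conv`, this version owns the raw kernel params and calls `jax.lax.conv_general_dilated` directly, so everything here (names, shapes, the epsilon) is an assumption rather than the original author's code:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class WeightNormConv(nn.Module):
    """Hypothetical weight-normalized 2D conv (NHWC inputs)."""
    features: int
    kernel_size: tuple = (3, 3)

    @nn.compact
    def __call__(self, x):
        kshape = self.kernel_size + (x.shape[-1], self.features)  # HWIO layout
        v = self.param("kernel_v", nn.initializers.lecun_normal(), kshape)
        g = self.param("kernel_g", nn.initializers.ones, (self.features,))
        # Direction/magnitude split: normalize v per output channel, scale by g.
        norm = jnp.sqrt(jnp.sum(v ** 2, axis=(0, 1, 2)) + 1e-6)
        kernel = v * (g / norm)
        return jax.lax.conv_general_dilated(
            x, kernel,
            window_strides=(1, 1),
            padding="SAME",
            dimension_numbers=("NHWC", "HWIO", "NHWC"),
        )
```

Recent Flax releases also ship a built-in `nn.WeightNorm` wrapper, which may be preferable if your version includes it.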