
L1/L2 weight regularization #1654

Answered by cgarciae
dnajera27 asked this question in Q&A

Hey! You can do something like this to get global L2 regularization, i.e. a penalty applied to every parameter array in the model:

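import jax
import jax.numpy as jnp

def l2_loss(x, alpha):
    # alpha times the sum of squared entries of one parameter array
    return alpha * (x ** 2).sum()

def loss_fn(params, apply_fn, x, y):
    # apply_fn stands in for your model's apply (e.g. a Flax model.apply);
    # mean squared error is only a placeholder for the task loss
    preds = apply_fn({"params": params}, x)
    loss = jnp.mean((preds - y) ** 2)
    # Global L2 penalty: sum l2_loss over every leaf of the params pytree
    loss += sum(
        l2_loss(w, alpha=0.001)
        for w in jax.tree_util.tree_leaves(params)
    )
    return loss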
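
Since the question title asks about both L1 and L2, here is a minimal self-contained sketch of both penalties, written the same way (summing over every leaf returned by `jax.tree_util.tree_leaves`). The helper names `l1_penalty` and `l2_penalty`, the toy linear model and the all-zero inputs are illustrative assumptions, not anything from Flax itself. With `x` set to zeros, the task loss contributes no gradient to the kernel, so the kernel gradient isolates the penalty term: `2 * alpha * w` for L2 and `alpha * sign(w)` for L1.

```python
import jax
import jax.numpy as jnp

def l2_penalty(params, alpha):
    # alpha * sum of squares over every leaf of the params pytree
    return alpha * sum(jnp.sum(w ** 2) for w in jax.tree_util.tree_leaves(params))

def l1_penalty(params, alpha):
    # alpha * sum of absolute values over every leaf (lasso-style penalty)
    return alpha * sum(jnp.sum(jnp.abs(w)) for w in jax.tree_util.tree_leaves(params))

def loss_fn(params, x, y, l1=0.0, l2=0.0):
    # Hypothetical linear model standing in for a Flax module's apply
    preds = x @ params["kernel"] + params["bias"]
    task_loss = jnp.mean((preds - y) ** 2)
    # A coefficient of 0.0 effectively disables that penalty
    return task_loss + l1_penalty(params, l1) + l2_penalty(params, l2)

params = {"kernel": jnp.array([[1.0], [-2.0]]), "bias": jnp.array([0.5])}
x = jnp.zeros((4, 2))  # zero inputs: the kernel gradient comes only from the penalty
y = jnp.zeros((4, 1))

# jax.grad differentiates with respect to the first argument (params)
grads_l2 = jax.grad(loss_fn)(params, x, y, l2=0.01)
grads_l1 = jax.grad(loss_fn)(params, x, y, l1=0.01)
```

Here `grads_l2["kernel"]` equals `2 * 0.01 * params["kernel"]` and `grads_l1["kernel"]` equals `0.01 * jnp.sign(params["kernel"])`. Note that, like the snippet in the answer, both penalties also cover the bias; filtering leaves (for example by key name) is a common variation if you only want to regularize kernels.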

Replies: 5 comments

Answer selected by dnajera27
Category
Q&A
6 participants