-
Notifications
You must be signed in to change notification settings - Fork 894
JAX BACKEND: add regularizer #1968
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Why the build failed? |
|
According to github copilot: Suggested Solution Check the URL: Ensure that the URL for pybind11 is correct and accessible. If the URL is outdated, update it to the latest version. Update dm-tree: Ensure that you are using the latest version of the dm-tree package, as it may have fixed the issue with the dependency. Retry the Build: Sometimes, transient network issues can cause such errors. Retrying the build might resolve the issue. Modify the Workflow File: If the issue persists, modify the workflow to include the installation of pybind11 before building dm-tree. So this is probably a network issue during the build |
a006409 to
1dacd63
Compare
|
I've made and reverted a dummy commit to trigger the build again: it worked this time |
8395dbe to
fb3d140
Compare
1202d7d to
107f970
Compare
Add regularizer for JAX BACKEND.
For Jax adding the regularization is straightforward: you apply the regularization function to the nn_params and add it to the loss. see google/flax#1654
I tried to stick to TensorFlow and paddle architecture: regularizer is a net property init from the regularization parameter. (even though regularizer is not used in the net here)
In that case, the regularizer is simply the function to be applied to the net parameters to compute the loss in model.py