deepxde/model.py (15 changes: 13 additions & 2 deletions)
@@ -107,6 +107,18 @@ def compile(
<https://www.paddlepaddle.org.cn/documentation/docs/en/develop/api/paddle/optimizer/lr/InverseTimeDecay_en.html>`_:
("inverse time", gamma)

- For backend JAX:

- `linear_schedule
<https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.linear_schedule>`_:
("linear", end_value, transition_steps)
- `cosine_decay_schedule
<https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.cosine_decay_schedule>`_:
("cosine", decay_steps, alpha)
- `exponential_decay
<https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.exponential_decay>`_:
("exponential", transition_steps, decay_rate)

loss_weights: A list specifying scalar coefficients (Python floats) to
weight the loss contributions. The loss value that will be minimized by
the model will then be the weighted sum of all individual losses,
@@ -417,8 +429,7 @@ def _compile_jax(self, lr, loss_fn, decay):
var.value for var in self.external_trainable_variables
]
self.params = [self.net.params, external_trainable_variables_val]
- # TODO: learning rate decay
- self.opt = optimizers.get(self.opt_name, learning_rate=lr)
+ self.opt = optimizers.get(self.opt_name, learning_rate=lr, decay=decay)
self.opt_state = self.opt.init(self.params)

@jax.jit
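To make the new `decay` argument concrete for the JAX backend, here is a minimal sketch of a compile call using one of the schedules documented above. The function-approximation setup (geometry, dataset, network sizes, iteration count) is illustrative only and assumes the JAX backend is selected; the part exercised from this PR is the decay tuple passed to `compile`.

import numpy as np
import deepxde as dde  # assumes DDE_BACKEND=jax

def func(x):
    return x * np.sin(5 * x)

geom = dde.geometry.Interval(-1, 1)
data = dde.data.Function(geom, func, 16, 100)
net = dde.nn.FNN([1] + [32] * 2 + [1], "tanh", "Glorot uniform")
model = dde.Model(data, net)

# ("exponential", transition_steps, decay_rate) is forwarded to
# optax.exponential_decay(lr, transition_steps, decay_rate) by the optimizer code.
model.compile("adam", lr=1e-3, decay=("exponential", 1000, 0.9))
model.train(iterations=2000)

The same call accepts ("linear", end_value, transition_steps) and ("cosine", decay_steps, alpha), as listed in the docstring above.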
deepxde/optimizers/jax/optimizers.py (8 changes: 7 additions & 1 deletion)
@@ -36,7 +36,13 @@ def get(optimizer, learning_rate=None, decay=None):
def _get_learningrate(lr, decay):
    if decay is None:
        return lr
-   # TODO: add optax's optimizer schedule
+   if decay[0] == "linear":
+       return optax.linear_schedule(lr, decay[1], decay[2])
+   if decay[0] == "cosine":
+       return optax.cosine_decay_schedule(lr, decay[1], decay[2])
+   if decay[0] == "exponential":
+       return optax.exponential_decay(lr, decay[1], decay[2])
+
    raise NotImplementedError(
        f"{decay[0]} learning rate decay to be implemented for backend jax."
    )
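For reference, each branch above returns an optax schedule: a callable mapping the optimizer step count to a learning rate, which optax optimizers accept directly in place of a float. A small sketch of the exponential case (the numbers are arbitrary):

import optax

# decay = ("exponential", 1000, 0.9) with lr = 1e-3 becomes:
schedule = optax.exponential_decay(
    init_value=1e-3,        # lr
    transition_steps=1000,  # decay[1]
    decay_rate=0.9,         # decay[2]
)
print(schedule(0))     # 0.001 at step 0
print(schedule(1000))  # 0.001 * 0.9 = 0.0009 after one transition period

opt = optax.adam(learning_rate=schedule)  # roughly what optimizers.get builds downstream

This is why `get` can pass the result of `_get_learningrate` wherever it previously passed the float learning rate.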