
Commit 5722ebb

Backend JAX: learning rate decay with Optax
1 parent b79d2fd commit 5722ebb

File tree

2 files changed: +14, -3 lines


deepxde/model.py

Lines changed: 7 additions & 2 deletions
@@ -107,6 +107,12 @@ def compile(
                   <https://www.paddlepaddle.org.cn/documentation/docs/en/develop/api/paddle/optimizer/lr/InverseTimeDecay_en.html>`_:
                   ("inverse time", gamma)
 
+            - For backend JAX:
+
+                - `linear_schedule <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.linear_schedule>`_: ("linear", end_value, transition_steps)
+                - `cosine_decay_schedule <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.cosine_decay_schedule>`_: ("cosine", decay_steps, alpha)
+                - `exponential_decay <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.exponential_decay>`_: ("exponential", transition_steps, decay_rate)
+
             loss_weights: A list specifying scalar coefficients (Python floats) to
                 weight the loss contributions. The loss value that will be minimized by
                 the model will then be the weighted sum of all individual losses,
@@ -417,8 +423,7 @@ def _compile_jax(self, lr, loss_fn, decay):
                 var.value for var in self.external_trainable_variables
             ]
             self.params = [self.net.params, external_trainable_variables_val]
-        # TODO: learning rate decay
-        self.opt = optimizers.get(self.opt_name, learning_rate=lr)
+        self.opt = optimizers.get(self.opt_name, learning_rate=lr, decay=decay)
         self.opt_state = self.opt.init(self.params)
 
         @jax.jit
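
For orientation, a hedged sketch of the user-facing call this enables (not part of the commit; the optimizer name, learning rate, tuple values, and the pre-built `model` object are illustrative assumptions):

# Sketch only: assumes DDE_BACKEND=jax and that `model` is an already
# constructed dde.Model (geometry/PDE/network setup omitted).
# The decay tuple follows the docstring added above:
# ("exponential", transition_steps, decay_rate).
model.compile("adam", lr=1e-3, decay=("exponential", 1000, 0.9))

The tuple is forwarded unchanged from `Model.compile` through `_compile_jax` to `optimizers.get`, as the hunk above shows.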

deepxde/optimizers/jax/optimizers.py

Lines changed: 7 additions & 1 deletion
@@ -36,7 +36,13 @@ def get(optimizer, learning_rate=None, decay=None):
 def _get_learningrate(lr, decay):
     if decay is None:
         return lr
-    # TODO: add optax's optimizer schedule
+    if decay[0] == "linear":
+        return optax.linear_schedule(lr, decay[1], decay[2])
+    elif decay[0] == "cosine":
+        return optax.cosine_decay_schedule(lr, decay[1], decay[2])
+    elif decay[0] == "exponential":
+        return optax.exponential_decay(lr, decay[1], decay[2])
+
     raise NotImplementedError(
         f"{decay[0]} learning rate decay to be implemented for backend jax."
     )
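
As a sanity check on the mapping above, here is a small self-contained sketch (plain Optax, illustrative hyperparameters) showing that each tuple's trailing values line up with the positional arguments of the corresponding Optax schedule, and that a schedule is just a callable from step count to learning rate:

import optax

lr = 1e-3

# ("linear", end_value, transition_steps)
linear = optax.linear_schedule(lr, 1e-5, 10_000)
# ("cosine", decay_steps, alpha)
cosine = optax.cosine_decay_schedule(lr, 10_000, 0.1)
# ("exponential", transition_steps, decay_rate)
exponential = optax.exponential_decay(lr, 1_000, 0.9)

# Each schedule maps an iteration count to a learning rate.
for step in (0, 1_000, 10_000):
    print(step, linear(step), cosine(step), exponential(step))

# Optax optimizers accept either a float or a schedule as the learning rate,
# which is why _get_learningrate can return either one unchanged.
opt = optax.adam(learning_rate=exponential)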
