
Commit a60cd74

Backend JAX: learning rate decay with Optax (#1992)
1 parent a68a0eb commit a60cd74

File tree

2 files changed: +20 -3 lines changed


deepxde/model.py

Lines changed: 13 additions & 2 deletions
@@ -107,6 +107,18 @@ def compile(
                       <https://www.paddlepaddle.org.cn/documentation/docs/en/develop/api/paddle/optimizer/lr/InverseTimeDecay_en.html>`_:
                       ("inverse time", gamma)
 
+                - For backend JAX:
+
+                    - `linear_schedule
+                      <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.linear_schedule>`_:
+                      ("linear", end_value, transition_steps)
+                    - `cosine_decay_schedule
+                      <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.cosine_decay_schedule>`_:
+                      ("cosine", decay_steps, alpha)
+                    - `exponential_decay
+                      <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.exponential_decay>`_:
+                      ("exponential", transition_steps, decay_rate)
+
             loss_weights: A list specifying scalar coefficients (Python floats) to
                 weight the loss contributions. The loss value that will be minimized by
                 the model will then be the weighted sum of all individual losses,
@@ -417,8 +429,7 @@ def _compile_jax(self, lr, loss_fn, decay):
                 var.value for var in self.external_trainable_variables
             ]
             self.params = [self.net.params, external_trainable_variables_val]
-        # TODO: learning rate decay
-        self.opt = optimizers.get(self.opt_name, learning_rate=lr)
+        self.opt = optimizers.get(self.opt_name, learning_rate=lr, decay=decay)
         self.opt_state = self.opt.init(self.params)
 
         @jax.jit
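
For context, a minimal usage sketch (not part of this commit) of the decay tuples documented above when compiling on the JAX backend; `data` and `net` are hypothetical placeholders for an already-constructed DeepXDE problem definition and network.

import deepxde as dde

# `data` and `net` stand in for a dde.data.* problem and a dde.nn.* network
# built for the JAX backend; they are placeholders in this sketch.
model = dde.Model(data, net)

# The three new JAX decay options, one tuple per Optax schedule:
model.compile("adam", lr=1e-3, decay=("linear", 1e-5, 10000))     # end_value, transition_steps
model.compile("adam", lr=1e-3, decay=("cosine", 10000, 0.1))      # decay_steps, alpha
model.compile("adam", lr=1e-3, decay=("exponential", 1000, 0.9))  # transition_steps, decay_rate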

deepxde/optimizers/jax/optimizers.py

Lines changed: 7 additions & 1 deletion
@@ -36,7 +36,13 @@ def get(optimizer, learning_rate=None, decay=None):
 def _get_learningrate(lr, decay):
     if decay is None:
         return lr
-    # TODO: add optax's optimizer schedule
+    if decay[0] == "linear":
+        return optax.linear_schedule(lr, decay[1], decay[2])
+    if decay[0] == "cosine":
+        return optax.cosine_decay_schedule(lr, decay[1], decay[2])
+    if decay[0] == "exponential":
+        return optax.exponential_decay(lr, decay[1], decay[2])
+
     raise NotImplementedError(
         f"{decay[0]} learning rate decay to be implemented for backend jax."
     )
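
As a sanity check on the mapping in _get_learningrate, here is a small standalone sketch using only public Optax APIs: each tuple resolves to an Optax schedule, i.e. a callable from step count to learning rate, which can be passed to an optimizer constructor in place of a float.

import optax

lr = 1e-3

# ("exponential", 1000, 0.9) maps to optax.exponential_decay(lr, 1000, 0.9).
schedule = optax.exponential_decay(lr, transition_steps=1000, decay_rate=0.9)

print(schedule(0))     # 0.001 at step 0
print(schedule(1000))  # 0.0009 after one transition period

# Optax optimizers accept a schedule wherever a float learning rate is accepted,
# which is why _get_learningrate may return either a float or a schedule.
optimizer = optax.adam(learning_rate=schedule)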
