diff --git a/deepxde/model.py b/deepxde/model.py
index abe7d83ec..7eae10a6f 100644
--- a/deepxde/model.py
+++ b/deepxde/model.py
@@ -107,6 +107,18 @@ def compile(
                       `_: ("inverse time", gamma)
 
+                - For backend JAX:
+
+                    - `linear_schedule
+                      `_:
+                      ("linear", end_value, transition_steps)
+                    - `cosine_decay_schedule
+                      `_:
+                      ("cosine", decay_steps, alpha)
+                    - `exponential_decay
+                      `_:
+                      ("exponential", transition_steps, decay_rate)
+
             loss_weights: A list specifying scalar coefficients (Python floats) to
                 weight the loss contributions. The loss value that will be minimized by
                 the model will then be the weighted sum of all individual losses,
@@ -417,8 +429,7 @@ def _compile_jax(self, lr, loss_fn, decay):
             var.value for var in self.external_trainable_variables
         ]
         self.params = [self.net.params, external_trainable_variables_val]
-        # TODO: learning rate decay
-        self.opt = optimizers.get(self.opt_name, learning_rate=lr)
+        self.opt = optimizers.get(self.opt_name, learning_rate=lr, decay=decay)
         self.opt_state = self.opt.init(self.params)
 
         @jax.jit
diff --git a/deepxde/optimizers/jax/optimizers.py b/deepxde/optimizers/jax/optimizers.py
index 8df252cfd..461eb2ffb 100644
--- a/deepxde/optimizers/jax/optimizers.py
+++ b/deepxde/optimizers/jax/optimizers.py
@@ -36,7 +36,13 @@ def get(optimizer, learning_rate=None, decay=None):
 def _get_learningrate(lr, decay):
     if decay is None:
         return lr
-    # TODO: add optax's optimizer schedule
+    if decay[0] == "linear":
+        return optax.linear_schedule(lr, decay[1], decay[2])
+    if decay[0] == "cosine":
+        return optax.cosine_decay_schedule(lr, decay[1], decay[2])
+    if decay[0] == "exponential":
+        return optax.exponential_decay(lr, decay[1], decay[2])
     raise NotImplementedError(
         f"{decay[0]} learning rate decay to be implemented for backend jax."
     )
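
For context, a minimal sketch (not part of the diff) of what the three `decay` tuples resolve to once `_get_learningrate` hands them to optax. The learning rate and schedule parameters below are illustrative values, not taken from the change:

```python
import optax

lr = 1e-3  # illustrative initial learning rate

# ("linear", end_value, transition_steps) -> optax.linear_schedule
linear = optax.linear_schedule(init_value=lr, end_value=1e-4, transition_steps=10_000)

# ("cosine", decay_steps, alpha) -> optax.cosine_decay_schedule
cosine = optax.cosine_decay_schedule(init_value=lr, decay_steps=10_000, alpha=0.1)

# ("exponential", transition_steps, decay_rate) -> optax.exponential_decay
exponential = optax.exponential_decay(init_value=lr, transition_steps=1_000, decay_rate=0.9)

# An optax schedule is a callable mapping the step count to a learning rate,
# and can be passed to an optimizer wherever a constant rate is accepted.
opt = optax.adam(learning_rate=exponential)
print(exponential(0), exponential(1_000))  # lr at step 0 and after 1000 steps
```

With this change, the JAX backend would accept calls such as `model.compile("adam", lr=1e-3, decay=("exponential", 1000, 0.9))`, matching the tuple formats documented in the docstring above.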