@@ -107,6 +107,18 @@ def compile(
          <https://www.paddlepaddle.org.cn/documentation/docs/en/develop/api/paddle/optimizer/lr/InverseTimeDecay_en.html>`_:
          ("inverse time", gamma)
 
+    - For backend JAX:
+
+      - `linear_schedule
+        <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.linear_schedule>`_:
+        ("linear", end_value, transition_steps)
+      - `cosine_decay_schedule
+        <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.cosine_decay_schedule>`_:
+        ("cosine", decay_steps, alpha)
+      - `exponential_decay
+        <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.exponential_decay>`_:
+        ("exponential", transition_steps, decay_rate)
+
 loss_weights: A list specifying scalar coefficients (Python floats) to
     weight the loss contributions. The loss value that will be minimized by
     the model will then be the weighted sum of all individual losses,
@@ -417,8 +429,7 @@ def _compile_jax(self, lr, loss_fn, decay):
             var.value for var in self.external_trainable_variables
         ]
         self.params = [self.net.params, external_trainable_variables_val]
-        # TODO: learning rate decay
-        self.opt = optimizers.get(self.opt_name, learning_rate=lr)
+        self.opt = optimizers.get(self.opt_name, learning_rate=lr, decay=decay)
         self.opt_state = self.opt.init(self.params)
 
         @jax.jit
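
Note (not part of the patch): a minimal sketch of how the decay tuples documented above could be turned into optax schedules on the JAX backend. The helper name `_get_learning_rate`, the dummy parameters, and the hard-coded values are assumptions for illustration only; the commit itself only forwards `decay` to `optimizers.get`, where the actual mapping presumably happens.

```python
# Illustration only: mapping ("linear" | "cosine" | "exponential", ...) decay
# tuples onto optax schedules. Helper name and values are hypothetical.
import jax.numpy as jnp
import optax


def _get_learning_rate(lr, decay):
    """Return a plain float, or an optax schedule built from a decay tuple."""
    if decay is None:
        return lr
    if decay[0] == "linear":
        # ("linear", end_value, transition_steps)
        return optax.linear_schedule(
            init_value=lr, end_value=decay[1], transition_steps=decay[2]
        )
    if decay[0] == "cosine":
        # ("cosine", decay_steps, alpha)
        return optax.cosine_decay_schedule(
            init_value=lr, decay_steps=decay[1], alpha=decay[2]
        )
    if decay[0] == "exponential":
        # ("exponential", transition_steps, decay_rate)
        return optax.exponential_decay(
            init_value=lr, transition_steps=decay[1], decay_rate=decay[2]
        )
    raise NotImplementedError(f"{decay[0]} decay is not supported for backend JAX")


# Usage sketch: Adam whose learning rate follows a cosine decay over 10000 steps.
opt = optax.adam(learning_rate=_get_learning_rate(1e-3, ("cosine", 10000, 0.1)))
opt_state = opt.init({"w": jnp.zeros(3)})  # dummy params; the schedule is accepted
```

Because the schedule is a callable, optax re-evaluates the learning rate at every update from the step count carried in `opt_state`, so no extra bookkeeping is needed in the training loop.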