Commit 8bc9141

Backend Paddle: Add LBFGS optimizer (#1095)

1 parent fc978ad

9 files changed: +78 −24 lines

deepxde/model.py (47 additions, 2 deletions)

```diff
@@ -472,11 +472,24 @@ def train_step(inputs, targets, auxiliary_vars):
             if self.lr_scheduler is not None:
                 self.lr_scheduler.step()

+        def train_step_lbfgs(inputs, targets, auxiliary_vars):
+            def closure():
+                losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
+                total_loss = paddle.sum(losses)
+                self.opt.clear_grad()
+                total_loss.backward()
+                return total_loss
+
+            self.opt.step(closure)
         # Callables
         self.outputs = outputs
         self.outputs_losses_train = outputs_losses_train
         self.outputs_losses_test = outputs_losses_test
-        self.train_step = train_step
+        self.train_step = (
+            train_step
+            if not optimizers.is_external_optimizer(self.opt_name)
+            else train_step_lbfgs
+        )

     def _outputs(self, training, inputs):
         if backend_name == "tensorflow.compat.v1":
```
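With this change, compiling with "L-BFGS" on the paddle backend dispatches training through the closure-based `train_step_lbfgs` above. A minimal usage sketch, not part of the commit: a hypothetical 1D Poisson problem, assuming the paddle backend is active (e.g. run with `DDE_BACKEND=paddle`):

```python
import deepxde as dde


def pde(x, y):
    # Residual of y'' = 2 on the interval (-1, 1).
    dy_xx = dde.grad.hessian(y, x)
    return dy_xx - 2


geom = dde.geometry.Interval(-1, 1)
bc = dde.icbc.DirichletBC(geom, lambda x: 0, lambda x, on_boundary: on_boundary)
data = dde.data.PDE(geom, pde, bc, num_domain=16, num_boundary=2)
net = dde.nn.FNN([1] + [20] * 2 + [1], "tanh", "Glorot normal")

model = dde.Model(data, net)
model.compile("L-BFGS")  # external optimizer, so train_step_lbfgs is selected
losshistory, train_state = model.train()  # iteration budget comes from LBFGS_options
```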
```diff
@@ -599,7 +612,7 @@ def train(
             elif backend_name == "pytorch":
                 self._train_pytorch_lbfgs()
             elif backend_name == "paddle":
-                raise NotImplementedError("L-BFGS will be implemented soon in PaddlePaddle")
+                self._train_paddle_lbfgs()
         else:
             if iterations is None:
                 raise ValueError("No iterations for {}.".format(self.opt_name))
```
```diff
@@ -740,6 +753,38 @@ def _train_pytorch_lbfgs(self):
             if self.stop_training:
                 break

+    def _train_paddle_lbfgs(self):
+        prev_n_iter = 0
+
+        while prev_n_iter < optimizers.LBFGS_options["maxiter"]:
+            self.callbacks.on_epoch_begin()
+            self.callbacks.on_batch_begin()
+
+            self.train_state.set_data_train(
+                *self.data.train_next_batch(self.batch_size)
+            )
+            self._train_step(
+                self.train_state.X_train,
+                self.train_state.y_train,
+                self.train_state.train_aux_vars,
+            )
+
+            n_iter = self.opt.state_dict()["state"]["n_iter"]
+            if prev_n_iter == n_iter:
+                # Converged
+                break
+
+            self.train_state.epoch += n_iter - prev_n_iter
+            self.train_state.step += n_iter - prev_n_iter
+            prev_n_iter = n_iter
+            self._test()
+
+            self.callbacks.on_batch_end()
+            self.callbacks.on_epoch_end()
+
+            if self.stop_training:
+                break
+
     def _test(self):
         (
             self.train_state.y_pred_train,
```
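The outer loop calls `opt.step(closure)` once per pass and reads the optimizer's internal iteration counter to detect convergence: paddle's L-BFGS runs up to `max_iter` inner iterations per `step()` call, so if `n_iter` stops advancing, no further progress was made. A standalone sketch of the same pattern on a toy least-squares problem (the `state_dict()["state"]["n_iter"]` access mirrors the commit; the toy setup is illustrative):

```python
import paddle
from paddle.incubate.optimizer import LBFGS

# Toy problem: recover w_true from y = x @ w_true.
x = paddle.randn([64, 3])
w_true = paddle.to_tensor([[1.0], [-2.0], [0.5]])
y = x @ w_true
w = paddle.create_parameter(shape=[3, 1], dtype="float32")

opt = LBFGS(
    lr=1,
    max_iter=100,  # inner iterations per step() call, cf. iter_per_step
    history_size=50,
    line_search_fn="strong_wolfe",
    parameters=[w],
)


def closure():
    loss = paddle.mean((x @ w - y) ** 2)
    opt.clear_grad()
    loss.backward()
    return loss


# Outer loop in the style of _train_paddle_lbfgs: stop once the
# optimizer's iteration counter no longer advances (converged).
prev_n_iter = 0
for _ in range(10):
    opt.step(closure)
    n_iter = opt.state_dict()["state"]["n_iter"]
    if n_iter == prev_n_iter:
        break
    prev_n_iter = n_iter
```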

deepxde/optimizers/config.py (7 additions, 6 deletions)

```diff
@@ -21,23 +21,24 @@ def set_LBFGS_options(
     - TensorFlow 1.x: `scipy.optimize.minimize <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html#optimize-minimize-lbfgsb>`_
     - TensorFlow 2.x: `tfp.optimizer.lbfgs_minimize <https://www.tensorflow.org/probability/api_docs/python/tfp/optimizer/lbfgs_minimize>`_
     - PyTorch: `torch.optim.LBFGS <https://pytorch.org/docs/stable/generated/torch.optim.LBFGS.html>`_
+    - Paddle: `paddle.incubate.optimizer.LBFGS <https://www.paddlepaddle.org.cn/documentation/docs/en/develop/api/paddle/incubate/optimizer/LBFGS_en.html>`_

     I find empirically that torch.optim.LBFGS and scipy.optimize.minimize are better than
     tfp.optimizer.lbfgs_minimize in terms of the final loss value.

     Args:
-        maxcor (int): `maxcor` (scipy), `num_correction_pairs` (tfp), `history_size` (torch).
+        maxcor (int): `maxcor` (scipy), `num_correction_pairs` (tfp), `history_size` (torch), `history_size` (paddle).
            The maximum number of variable metric corrections used to define the limited
            memory matrix. (The limited memory BFGS method does not store the full
            hessian but uses this many terms in an approximation to it.)
-        ftol (float): `ftol` (scipy), `f_relative_tolerance` (tfp), `tolerance_change` (torch).
+        ftol (float): `ftol` (scipy), `f_relative_tolerance` (tfp), `tolerance_change` (torch), `tolerance_change` (paddle).
            The iteration stops when `(f^k - f^{k+1})/max{|f^k|,|f^{k+1}|,1} <= ftol`.
-        gtol (float): `gtol` (scipy), `tolerance` (tfp), `tolerance_grad` (torch).
+        gtol (float): `gtol` (scipy), `tolerance` (tfp), `tolerance_grad` (torch), `tolerance_grad` (paddle).
            The iteration will stop when `max{|proj g_i | i = 1, ..., n} <= gtol` where
            `pg_i` is the i-th component of the projected gradient.
-        maxiter (int): `maxiter` (scipy), `max_iterations` (tfp), `max_iter` (torch).
+        maxiter (int): `maxiter` (scipy), `max_iterations` (tfp), `max_iter` (torch), `max_iter` (paddle).
            Maximum number of iterations.
-        maxfun (int): `maxfun` (scipy), `max_eval` (torch).
+        maxfun (int): `maxfun` (scipy), `max_eval` (torch), `max_eval` (paddle).
            Maximum number of function evaluations. If ``None``, `maxiter` * 1.25.
        maxls (int): `maxls` (scipy), `max_line_search_iterations` (tfp).
            Maximum number of line search steps (per iteration).
@@ -62,7 +63,7 @@ def set_LBFGS_options(


     # Backend-dependent options
-    if backend_name == "pytorch":
+    if backend_name in ["pytorch", "paddle"]:
         # number of iterations per optimization call
         LBFGS_options["iter_per_step"] = min(1000, LBFGS_options["maxiter"])
         LBFGS_options["fun_per_step"] = (
```

deepxde/optimizers/paddle/optimizers.py (19 additions, 5 deletions)

```diff
@@ -1,6 +1,9 @@
 __all__ = ["get", "is_external_optimizer"]

 import paddle
+from paddle.incubate.optimizer import LBFGS
+
+from ..config import LBFGS_options


 def _get_lr_scheduler(lr, decay):
@@ -22,10 +25,21 @@ def get(params, optimizer, learning_rate=None, decay=None):
     if isinstance(optimizer, paddle.optimizer.Optimizer):
         return optimizer

-    if is_external_optimizer(optimizer):
-        # TODO: add support for L-BFGS and L-BFGS-B
-        raise NotImplementedError(f"{optimizer} is not implemented in PaddlePaddle")
-
+    if optimizer in ["L-BFGS", "L-BFGS-B"]:
+        if learning_rate is not None or decay is not None:
+            print("Warning: learning rate is ignored for {}".format(optimizer))
+        optim = LBFGS(
+            lr=1,
+            max_iter=LBFGS_options["iter_per_step"],
+            max_eval=LBFGS_options["fun_per_step"],
+            tolerance_grad=LBFGS_options["gtol"],
+            tolerance_change=LBFGS_options["ftol"],
+            history_size=LBFGS_options["maxcor"],
+            line_search_fn='strong_wolfe',
+            parameters=params,
+        )
+        return optim
+
     if learning_rate is None:
         raise ValueError("No learning rate for {}.".format(optimizer))

@@ -34,4 +48,4 @@ def get(params, optimizer, learning_rate=None, decay=None):

     if optimizer == "adam":
         return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params)
-    raise NotImplementedError(f"{optimizer} is not implemented in PaddlePaddle")
+    raise NotImplementedError(f"{optimizer} to be implemented for backend Paddle.")
```

examples/pinn_forward/Beltrami_flow.py (1 addition, 1 deletion)

```diff
@@ -1,4 +1,4 @@
-"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch"""
+"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
 import deepxde as dde
 import numpy as np
```
examples/pinn_forward/Burgers_RAR.py (1 addition, 1 deletion)

```diff
@@ -1,4 +1,4 @@
-"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch"""
+"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
 import deepxde as dde
 import numpy as np
```
examples/pinn_forward/Kovasznay_flow.py (1 addition, 1 deletion)

```diff
@@ -1,4 +1,4 @@
-"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch"""
+"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
 import deepxde as dde
 import numpy as np
```
examples/pinn_forward/Lotka_Volterra.py (0 additions, 6 deletions)

```diff
@@ -77,13 +77,7 @@ def input_transform(t):
 # def input_transform(t):
 #     return paddle.concat(
 #         (
-#             t,
 #             paddle.sin(t),
-#             paddle.sin(2 * t),
-#             paddle.sin(3 * t),
-#             paddle.sin(4 * t),
-#             paddle.sin(5 * t),
-#             paddle.sin(6 * t),
 #         ),
 #         axis=1,
 #     )
```
examples/pinn_forward/Poisson_Lshape.py (1 addition, 1 deletion)

```diff
@@ -1,4 +1,4 @@
-"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, jax"""
+"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, jax, paddle"""
 import deepxde as dde
```
examples/pinn_forward/heat.py (1 addition, 1 deletion)

```diff
@@ -1,4 +1,4 @@
-"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch"""
+"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
 import deepxde as dde
 import numpy as np
```
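The docstring updates above advertise paddle support; each of these scripts can now be run with the paddle backend selected. A sketch of selecting the backend programmatically, equivalent to `DDE_BACKEND=paddle python heat.py`:

```python
import os

# deepxde reads DDE_BACKEND at import time, so set it first.
os.environ["DDE_BACKEND"] = "paddle"

import deepxde as dde  # noqa: E402  (imported after the env var on purpose)
```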