Skip to content

Commit 8275aeb

Browse files
Add NNCG to optimizers submodule (#1661)
1 parent bb1d3ac commit 8275aeb

File tree

7 files changed

+428
-7
lines changed

7 files changed

+428
-7
lines changed

deepxde/model.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,11 +367,22 @@ def closure():
367367
if self.lr_scheduler is not None:
368368
self.lr_scheduler.step()
369369

370+
def train_step_nncg(inputs, targets, auxiliary_vars):
371+
def closure():
372+
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
373+
total_loss = torch.sum(losses)
374+
self.opt.zero_grad()
375+
return total_loss
376+
377+
self.opt.step(closure)
378+
if self.lr_scheduler is not None:
379+
self.lr_scheduler.step()
380+
370381
# Callables
371382
self.outputs = outputs
372383
self.outputs_losses_train = outputs_losses_train
373384
self.outputs_losses_test = outputs_losses_test
374-
self.train_step = train_step
385+
self.train_step = train_step if self.opt_name != "NNCG" else train_step_nncg
375386

376387
def _compile_jax(self, lr, loss_fn, decay):
377388
"""jax"""
@@ -652,7 +663,10 @@ def train(
652663
elif backend_name == "tensorflow":
653664
self._train_tensorflow_tfp(verbose=verbose)
654665
elif backend_name == "pytorch":
655-
self._train_pytorch_lbfgs(verbose=verbose)
666+
if self.opt_name == "L-BFGS":
667+
self._train_pytorch_lbfgs(verbose=verbose)
668+
elif self.opt_name == "NNCG":
669+
self._train_sgd(iterations, display_every, verbose=verbose)
656670
elif backend_name == "paddle":
657671
self._train_paddle_lbfgs(verbose=verbose)
658672
else:

deepxde/optimizers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import importlib
22
import sys
33

4-
from .config import LBFGS_options, set_LBFGS_options
4+
from .config import LBFGS_options, set_LBFGS_options, NNCG_options, set_NNCG_options
55
from ..backend import backend_name
66

77

deepxde/optimizers/config.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
__all__ = ["set_LBFGS_options", "set_hvd_opt_options"]
1+
__all__ = ["set_LBFGS_options", "set_NNCG_options", "set_hvd_opt_options"]
22

33
from ..backend import backend_name
44
from ..config import hvd
55

66
LBFGS_options = {}
7+
NNCG_options = {}
78
if hvd is not None:
89
hvd_opt_options = {}
910

@@ -60,6 +61,60 @@ def set_LBFGS_options(
6061
LBFGS_options["maxls"] = maxls
6162

6263

64+
def set_NNCG_options(
65+
lr=1,
66+
rank=50,
67+
mu=1e-1,
68+
updatefreq=20,
69+
chunksz=1,
70+
cgtol=1e-16,
71+
cgmaxiter=1000,
72+
lsfun="armijo",
73+
verbose=False,
74+
):
75+
"""Sets the hyperparameters of NysNewtonCG (NNCG).
76+
77+
The NNCG optimizer only supports PyTorch.
78+
79+
Args:
80+
lr (float):
81+
Learning rate (before line search).
82+
rank (int):
83+
Rank of preconditioner matrix used in preconditioned conjugate gradient.
84+
mu (float):
85+
Hessian damping parameter.
86+
updatefreq (int):
87+
How often the preconditioner matrix in preconditioned
88+
conjugate gradient is updated. This parameter is not directly used in NNCG,
89+
instead it is used in _train_pytorch_nncg in deepxde/model.py.
90+
chunksz (int):
91+
Number of Hessian-vector products to compute in parallel when constructing
92+
preconditioner. If `chunk_size` is 1, the Hessian-vector products are
93+
computed serially.
94+
cgtol (float):
95+
Convergence tolerance for the conjugate gradient method. The iteration stops
96+
when `||r||_2 <= cgtol`, where `r` is the residual. Note that this condition
97+
is based on the absolute tolerance, not the relative tolerance.
98+
cgmaxiter (int):
99+
Maximum number of iterations for the conjugate gradient method.
100+
lsfun (str):
101+
The line search function used to find the step size. The default value is
102+
"armijo". The other option is None.
103+
verbose (bool):
104+
If `True`, prints the eigenvalues of the Nyström approximation
105+
of the Hessian.
106+
"""
107+
NNCG_options["lr"] = lr
108+
NNCG_options["rank"] = rank
109+
NNCG_options["mu"] = mu
110+
NNCG_options["updatefreq"] = updatefreq
111+
NNCG_options["chunksz"] = chunksz
112+
NNCG_options["cgtol"] = cgtol
113+
NNCG_options["cgmaxiter"] = cgmaxiter
114+
NNCG_options["lsfun"] = lsfun
115+
NNCG_options["verbose"] = verbose
116+
117+
63118
def set_hvd_opt_options(
64119
compression=None,
65120
op=None,
@@ -91,6 +146,7 @@ def set_hvd_opt_options(
91146

92147

93148
set_LBFGS_options()
149+
set_NNCG_options()
94150
if hvd is not None:
95151
set_hvd_opt_options()
96152

0 commit comments

Comments
 (0)