Skip to content

Commit ce1dcc8

Browse files
format with black
1 parent 5c9beba commit ce1dcc8

File tree

4 files changed

+20
-11
lines changed

4 files changed

+20
-11
lines changed

deepxde/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -658,8 +658,10 @@ def train(
658658
elif self.opt_name == "NNCG":
659659
self._train_pytorch_nncg(iterations, display_every)
660660
else:
661-
raise ValueError("Only 'L-BFGS' and 'NNCG' are supported as \
662-
external optimizers for PyTorch.")
661+
raise ValueError(
662+
"Only 'L-BFGS' and 'NNCG' are supported as \
663+
external optimizers for PyTorch."
664+
)
663665
elif backend_name == "paddle":
664666
self._train_paddle_lbfgs()
665667
else:
@@ -827,7 +829,6 @@ def _train_pytorch_nncg(self, iterations, display_every):
827829
if self.stop_training:
828830
break
829831

830-
831832
def _train_paddle_lbfgs(self):
832833
prev_n_iter = 0
833834

deepxde/optimizers/config.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def set_LBFGS_options(
6060
LBFGS_options["maxfun"] = maxfun if maxfun is not None else int(maxiter * 1.25)
6161
LBFGS_options["maxls"] = maxls
6262

63+
6364
def set_NNCG_options(
6465
lr=1,
6566
rank=10,
@@ -69,7 +70,7 @@ def set_NNCG_options(
6970
cgtol=1e-16,
7071
cgmaxiter=1000,
7172
lsfun="armijo",
72-
verbose=False
73+
verbose=False,
7374
):
7475
"""Sets the hyperparameters of NysNewtonCG (NNCG).
7576
@@ -80,11 +81,11 @@ def set_NNCG_options(
8081
Rank of preconditioner matrix used in preconditioned conjugate gradient.
8182
mu (float): `mu` (torch).
8283
Hessian damping parameter.
83-
updatefreq (int): How often the preconditioner matrix in preconditioned
84+
updatefreq (int): How often the preconditioner matrix in preconditioned
8485
conjugate gradient is updated. This parameter is not directly used in NNCG,
8586
instead it is used in _train_pytorch_nncg in deepxde/model.py.
8687
chunksz (int): `chunk_size` (torch).
87-
Number of Hessian-vector products to compute in parallel when constructing
88+
Number of Hessian-vector products to compute in parallel when constructing
8889
preconditioner. If `chunk_size` is 1, the Hessian-vector products are
8990
computed serially.
9091
cgtol (float): `cg_tol` (torch).
@@ -110,6 +111,7 @@ def set_NNCG_options(
110111
NNCG_options["lsfun"] = lsfun
111112
NNCG_options["verbose"] = verbose
112113

114+
113115
def set_hvd_opt_options(
114116
compression=None,
115117
op=None,

deepxde/optimizers/pytorch/nncg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class NNCG(Optimizer):
8181
8282
The parameters rank and mu will probably need to be tuned for your specific problem.
8383
If the optimizer is running very slowly, you can try one of the following:
84-
- Increase the rank (this should increase the
84+
- Increase the rank (this should increase the
8585
accuracy of the Nyström approximation in PCG)
8686
- Reduce cg_tol (this will allow PCG to terminate with a less accurate solution)
8787
- Reduce cg_max_iters (this will allow PCG to terminate after fewer iterations)
@@ -155,7 +155,7 @@ def step(self, closure):
155155
156156
Args:
157157
closure (callable): A closure that reevaluates the model
158-
and returns the loss w.r.t. the parameters.
158+
and returns the loss w.r.t. the parameters.
159159
"""
160160
if self.n_iters == 0:
161161
# Store the previous direction for warm starting PCG

examples/pinn_forward/Burgers_NNCG.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
2+
23
import deepxde as dde
34
import numpy as np
45

@@ -49,7 +50,10 @@ def pde(x, y):
4950
y_pred = model.predict(X)
5051
f = model.predict(X, operator=pde)
5152
print("Mean residual after Adam+L-BFGS:", np.mean(np.absolute(f)))
52-
print("L2 relative error after Adam+L-BFGS:", dde.metrics.l2_relative_error(y_true, y_pred))
53+
print(
54+
"L2 relative error after Adam+L-BFGS:",
55+
dde.metrics.l2_relative_error(y_true, y_pred),
56+
)
5357
np.savetxt("test_adam_lbfgs.dat", np.hstack((X, y_true, y_pred)))
5458

5559
# Run NNCG after Adam+L-BFGS
@@ -62,6 +66,8 @@ def pde(x, y):
6266
y_pred = model.predict(X)
6367
f = model.predict(X, operator=pde)
6468
print("Mean residual after Adam+L-BFGS+NNCG:", np.mean(np.absolute(f)))
65-
print("L2 relative error after Adam+L-BFGS+NNCG:",
66-
dde.metrics.l2_relative_error(y_true, y_pred))
69+
print(
70+
"L2 relative error after Adam+L-BFGS+NNCG:",
71+
dde.metrics.l2_relative_error(y_true, y_pred),
72+
)
6773
np.savetxt("test_adam_lbfgs_nncg.dat", np.hstack((X, y_true, y_pred)))

0 commit comments

Comments
 (0)