
Commit af0083b

Merge pull request #31 from lululxvi/master
latest update
2 parents cd06627 + 99c3626 commit af0083b

13 files changed: +357 lines, −57 lines

deepxde/backend/jax/tensor.py

Lines changed: 4 additions & 0 deletions
@@ -165,6 +165,10 @@ def reduce_max(input_tensor):
     return jnp.max(input_tensor)
 
 
+def norm(tensor, ord=None, axis=None, keepdims=False):
+    return jnp.linalg.norm(tensor, ord=ord, axis=axis, keepdims=keepdims)
+
+
 def zeros(shape, dtype):
     return jnp.zeros(shape, dtype=dtype)

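The new backend op is a thin wrapper around `jnp.linalg.norm`. A minimal sketch of calling it through the backend module, assuming the JAX backend is selected and that `deepxde.backend` re-exports it like the other tensor ops:

    import jax.numpy as jnp
    from deepxde import backend as bkd

    x = jnp.array([[3.0, 4.0], [1.0, 0.0]])
    print(bkd.norm(x, axis=1))                        # per-row L2 norms: [5. 1.]
    print(bkd.norm(x, ord=1, axis=1, keepdims=True))  # per-row L1 norms, shape (2, 1)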
deepxde/backend/paddle/__init__.py

Lines changed: 17 additions & 1 deletion
@@ -1 +1,17 @@
-from .tensor import *  # pylint: disable=redefined-builtin
+import os
+
+from .tensor import *  # pylint: disable=redefined-builtin
+
+# enable prim if specified
+enable_prim_value = os.getenv("PRIM")
+enable_prim = enable_prim_value.lower() in ['1', 'true', 'yes', 'on'] if enable_prim_value else False
+if enable_prim:
+    # Mostly for compiler running with dy2st.
+    from paddle.framework import core
+
+    core.set_prim_eager_enabled(True)
+    # The following protected member access is required.
+    # There is no alternative public API available now.
+    # pylint: disable=protected-access
+    core._set_prim_all_enabled(True)
+    print("Prim mode is enabled.")

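Because the PRIM check runs when the paddle backend module is imported, the environment variable has to be set before DeepXDE is imported. A minimal sketch, where the accepted values are the ones listed in the snippet above and DDE_BACKEND is DeepXDE's usual backend selector:

    import os

    # Must be set before `import deepxde`, since the check above runs at import time.
    os.environ["PRIM"] = "1"             # any of "1", "true", "yes", "on"
    os.environ["DDE_BACKEND"] = "paddle"

    import deepxde as dde                # prints "Prim mode is enabled."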
deepxde/config.py

Lines changed: 21 additions & 2 deletions
@@ -40,6 +40,8 @@
 
 # Default float type
 real = Real(32)
+# Using mixed precision
+mixed = False
 # Random seed
 random_seed = None
 if backend_name == "jax":
@@ -71,11 +73,14 @@ def default_float():
 def set_default_float(value):
     """Sets the default float type.
 
-    The default floating point type is 'float32'.
+    The default floating point type is 'float32'. Mixed precision uses the method in the paper:
+    `J. Hayford, J. Goldman-Wetzler, E. Wang, & L. Lu. Speeding up and reducing memory usage for scientific machine learning via mixed precision.
+    Computer Methods in Applied Mechanics and Engineering, 428, 117093, 2024 <https://doi.org/10.1016/j.cma.2024.117093>`_.
 
     Args:
-        value (String): 'float16', 'float32', or 'float64'.
+        value (String): 'float16', 'float32', 'float64', or 'mixed' (mixed precision).
     """
+    global mixed
     if value == "float16":
         print("Set the default float type to float16")
         real.set_float16()
@@ -85,6 +90,20 @@ def set_default_float(value):
     elif value == "float64":
         print("Set the default float type to float64")
         real.set_float64()
+    elif value == "mixed":
+        print("Set the float type to mixed precision of float16 and float32")
+        mixed = True
+        if backend_name == "tensorflow":
+            real.set_float16()
+            tf.keras.mixed_precision.set_global_policy("mixed_float16")
+            return  # don't try to set it again below
+        if backend_name == "pytorch":
+            # Use float16 during the forward and backward passes, but store in float32
+            real.set_float32()
+        else:
+            raise ValueError(
+                f"{backend_name} backend does not currently support mixed precision."
+            )
     else:
         raise ValueError(f"{value} not supported in deepXDE")
     if backend_name in ["tensorflow.compat.v1", "tensorflow"]:

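With this change, mixed precision is requested the same way as a concrete dtype. A minimal usage sketch, assuming the tensorflow or pytorch backend is active (other backends raise the ValueError added above):

    import deepxde as dde

    # Call before building the network and dde.Model so that layers pick up the float16 policy.
    dde.config.set_default_float("mixed")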
deepxde/data/mf.py

Lines changed: 15 additions & 11 deletions
@@ -1,7 +1,8 @@
 import numpy as np
 
 from .data import Data
-from ..backend import tf
+from .. import backend as bkd
+from .. import config
 from ..utils import run_if_any_none, standardize
 
 
@@ -83,20 +84,20 @@ def __init__(
         standardize=False,
     ):
         if X_lo_train is not None:
-            self.X_lo_train = X_lo_train
-            self.X_hi_train = X_hi_train
-            self.y_lo_train = y_lo_train
-            self.y_hi_train = y_hi_train
-            self.X_hi_test = X_hi_test
-            self.y_hi_test = y_hi_test
+            self.X_lo_train = X_lo_train.astype(config.real(np))
+            self.X_hi_train = X_hi_train.astype(config.real(np))
+            self.y_lo_train = y_lo_train.astype(config.real(np))
+            self.y_hi_train = y_hi_train.astype(config.real(np))
+            self.X_hi_test = X_hi_test.astype(config.real(np))
+            self.y_hi_test = y_hi_test.astype(config.real(np))
         elif fname_lo_train is not None:
-            data = np.loadtxt(fname_lo_train)
+            data = np.loadtxt(fname_lo_train).astype(config.real(np))
             self.X_lo_train = data[:, col_x]
             self.y_lo_train = data[:, col_y]
-            data = np.loadtxt(fname_hi_train)
+            data = np.loadtxt(fname_hi_train).astype(config.real(np))
             self.X_hi_train = data[:, col_x]
             self.y_hi_train = data[:, col_y]
-            data = np.loadtxt(fname_hi_test)
+            data = np.loadtxt(fname_hi_test).astype(config.real(np))
             self.X_hi_test = data[:, col_x]
             self.y_hi_test = data[:, col_y]
         else:
@@ -116,7 +117,10 @@ def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
         return [loss_lo, loss_hi]
 
     def losses_test(self, targets, outputs, loss_fn, inputs, model, aux=None):
-        return [0, loss_fn(targets[1], outputs[1])]
+        return [
+            bkd.as_tensor(0, dtype=config.real(bkd.lib)),
+            loss_fn(targets[1], outputs[1]),
+        ]
 
     @run_if_any_none("X_train", "y_train")
     def train_next_batch(self, batch_size=None):

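The added casts make the stored arrays follow `config.real(np)` rather than whatever dtype the inputs or `np.loadtxt` happen to carry. A hedged sketch of constructing the multifidelity dataset from in-memory arrays, assuming the existing `dde.data.MfDataSet` class; the data below is synthetic and only illustrative:

    import numpy as np
    import deepxde as dde

    # float64 inputs are now cast to dde.config.real(np) (float32 by default) inside the dataset.
    X_lo = np.linspace(0, 1, 51)[:, None]
    X_hi = np.linspace(0, 1, 15)[:, None]
    y_lo = np.sin(8 * np.pi * X_lo)
    y_hi = np.sin(8 * np.pi * X_hi) ** 2

    data = dde.data.MfDataSet(
        X_lo_train=X_lo, y_lo_train=y_lo,
        X_hi_train=X_hi, y_hi_train=y_hi,
        X_hi_test=X_hi, y_hi_test=y_hi,
    )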
deepxde/model.py

Lines changed: 5 additions & 1 deletion
@@ -374,7 +374,11 @@ def closure():
                 total_loss.backward()
                 return total_loss
 
-            self.opt.step(closure)
+            def closure_mixed():
+                with torch.autocast(device_type=torch.get_default_device().type, dtype=torch.float16):
+                    return closure()
+
+            self.opt.step(closure if not config.mixed else closure_mixed)
             if self.lr_scheduler is not None:
                 self.lr_scheduler.step()

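For the PyTorch backend, the change applies mixed precision by wrapping the existing optimizer closure in `torch.autocast`, so the forward and backward passes run in reduced precision while the parameters stay in float32. A self-contained sketch of the same wrapping pattern outside DeepXDE; the model and data here are placeholders:

    import torch

    model = torch.nn.Linear(2, 1)
    opt = torch.optim.LBFGS(model.parameters())
    x, y = torch.randn(64, 2), torch.randn(64, 1)

    def closure():
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    def closure_mixed():
        # float16 autocast requires a GPU; bfloat16 is the low-precision type supported on CPU.
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
        with torch.autocast(device_type=device_type, dtype=dtype):
            return closure()

    opt.step(closure_mixed)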
deepxde/nn/paddle/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,11 +4,13 @@
     "DeepONet",
     "DeepONetCartesianProd",
     "FNN",
+    "MfNN",
     "MsFFN",
     "PFNN",
     "STMsFFN",
 ]
 
 from .deeponet import DeepONet, DeepONetCartesianProd
 from .fnn import FNN, PFNN
+from .mfnn import MfNN
 from .msffn import MsFFN, STMsFFN

deepxde/nn/paddle/mfnn.py

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+import paddle
+
+from .nn import NN
+from .. import activations
+from .. import initializers
+from .. import regularizers
+from ... import config
+
+
+class MfNN(NN):
+    """Multifidelity neural networks."""
+
+    def __init__(
+        self,
+        layer_sizes_low_fidelity,
+        layer_sizes_high_fidelity,
+        activation,
+        kernel_initializer,
+        regularization=None,
+        residue=False,
+        trainable_low_fidelity=True,
+        trainable_high_fidelity=True,
+    ):
+        super().__init__()
+        self.layer_size_lo = layer_sizes_low_fidelity
+        self.layer_size_hi = layer_sizes_high_fidelity
+
+        self.activation = activations.get(activation)
+        self.initializer = initializers.get(kernel_initializer)
+        self.trainable_lo = trainable_low_fidelity
+        self.trainable_hi = trainable_high_fidelity
+        self.residue = residue
+        self.regularizer = regularizers.get(regularization)
+
+        # low fidelity
+        self.linears_lo = self._init_dense(self.layer_size_lo, self.trainable_lo)
+
+        # high fidelity
+        # linear part
+        self.linears_hi_l = paddle.nn.Linear(
+            in_features=self.layer_size_lo[0] + self.layer_size_lo[-1],
+            out_features=self.layer_size_hi[-1],
+            weight_attr=paddle.ParamAttr(initializer=self.initializer),
+        )
+        if not self.trainable_hi:
+            for param in self.linears_hi_l.parameters():
+                param.stop_gradient = False
+        # nonlinear part
+        self.layer_size_hi = [
+            self.layer_size_lo[0] + self.layer_size_lo[-1]
+        ] + self.layer_size_hi
+        self.linears_hi = self._init_dense(self.layer_size_hi, self.trainable_hi)
+        # linear + nonlinear
+        if not self.residue:
+            alpha = self._init_alpha(0.0, self.trainable_hi)
+            self.add_parameter("alpha", alpha)
+        else:
+            alpha1 = self._init_alpha(0.0, self.trainable_hi)
+            alpha2 = self._init_alpha(0.0, self.trainable_hi)
+            self.add_parameter("alpha1", alpha1)
+            self.add_parameter("alpha2", alpha2)
+
+    def _init_dense(self, layer_size, trainable):
+        linears = paddle.nn.LayerList()
+        for i in range(len(layer_size) - 1):
+            linear = paddle.nn.Linear(
+                in_features=layer_size[i],
+                out_features=layer_size[i + 1],
+                weight_attr=paddle.ParamAttr(initializer=self.initializer),
+            )
+            if not trainable:
+                for param in linear.parameters():
+                    param.stop_gradient = False
+            linears.append(linear)
+        return linears
+
+    def _init_alpha(self, value, trainable):
+        alpha = paddle.create_parameter(
+            shape=[1],
+            dtype=config.real(paddle),
+            default_initializer=paddle.nn.initializer.Constant(value),
+        )
+        alpha.stop_gradient = not trainable
+        return alpha
+
+    def forward(self, inputs):
+        # low fidelity
+        y = inputs
+        for i, linear in enumerate(self.linears_lo):
+            y = linear(y)
+            if i != len(self.linears_lo) - 1:
+                y = self.activation(y)
+        y_lo = y
+
+        # high fidelity
+        x_hi = paddle.concat([inputs, y_lo], axis=1)
+        # linear
+        y_hi_l = self.linears_hi_l(x_hi)
+        # nonlinear
+        y = x_hi
+        for i, linear in enumerate(self.linears_hi):
+            y = linear(y)
+            if i != len(self.linears_hi) - 1:
+                y = self.activation(y)
+        y_hi_nl = y
+        # linear + nonlinear
+        if not self.residue:
+            alpha = paddle.tanh(self.alpha)
+            y_hi = y_hi_l + alpha * y_hi_nl
+        else:
+            alpha1 = paddle.tanh(self.alpha1)
+            alpha2 = paddle.tanh(self.alpha2)
+            y_hi = y_lo + 0.1 * (alpha1 * y_hi_l + alpha2 * y_hi_nl)
+
+        return y_lo, y_hi

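A hedged sketch of training the new paddle `MfNN` on a multifidelity dataset, assuming `DDE_BACKEND=paddle` and that the class is exposed as `dde.nn.MfNN` per the `__init__.py` change above; layer sizes, optimizer, and iteration count are illustrative, and `data` refers to an `MfDataSet` like the one sketched earlier:

    import deepxde as dde

    net = dde.nn.MfNN(
        [1, 20, 20, 1],   # low-fidelity subnetwork: 1 input, two hidden layers, 1 output
        [10, 10, 1],      # nonlinear high-fidelity subnetwork; its input is concat([x, y_lo])
        "tanh",
        "Glorot uniform",
    )
    model = dde.Model(data, net)
    model.compile("adam", lr=0.001)
    model.train(iterations=10000)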
deepxde/zcs/operator.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 class PDEOperatorCartesianProd(BasePDEOperatorCartesianProd):
     """Derived `PDEOperatorCartesianProd` class for ZCS support."""
 
-    def _losses(self, outputs, loss_fn, inputs, model, num_func):
+    def _losses(self, outputs, loss_fn, inputs, model, num_func, aux):
         # PDE
         f = []
         if self.pde.pde is not None:

docs/user/faq.rst

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ General usage
   | **A**: `#5`_
 - | **Q**: By default, DeepXDE uses ``float32``. How can I use ``float64``?
   | **A**: `#28`_
+- | **Q**: How can I use mixed precision training?
+  | **A**: Use ``dde.config.set_default_float("mixed")`` with the ``tensorflow`` or ``pytorch`` backends. See `this paper <https://doi.org/10.1016/j.cma.2024.117093>`_ for more information.
 - | **Q**: I want to set the global random seeds.
   | **A**: `#353`_
 - | **Q**: GPU.
