
Commit d1df15a

Merge branch 'lululxvi:master' into fix-variable-compile
2 parents: d4ed4b6 + 18400e5

21 files changed: +457 −108 lines (diffs for 8 of the 21 files are shown below)

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
         os: [ubuntu-latest, macos-latest, windows-latest]

     steps:

README.md

Lines changed: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ DeepXDE requires one of the following backend-specific dependencies to be installed:

 - TensorFlow 1.x: [TensorFlow](https://www.tensorflow.org)>=2.7.0
 - TensorFlow 2.x: [TensorFlow](https://www.tensorflow.org)>=2.3.0, [TensorFlow Probability](https://www.tensorflow.org/probability)>=0.11.0
-- PyTorch: [PyTorch](https://pytorch.org)>=1.9.0
+- PyTorch: [PyTorch](https://pytorch.org)>=2.0.0
 - JAX: [JAX](https://jax.readthedocs.io), [Flax](https://flax.readthedocs.io), [Optax](https://optax.readthedocs.io)
 - PaddlePaddle: [PaddlePaddle](https://www.paddlepaddle.org.cn/en)>=2.6.0

deepxde/backend/pytorch/tensor.py

Lines changed: 2 additions & 2 deletions

@@ -4,8 +4,8 @@
 import torch


-if Version(torch.__version__) < Version("1.9.0"):
-    raise RuntimeError("DeepXDE requires PyTorch>=1.9.0.")
+if Version(torch.__version__) < Version("2.0.0"):
+    raise RuntimeError("DeepXDE requires PyTorch>=2.0.0.")

 # To write device-agnostic (CPU or GPU) code, a common pattern is to first determine
 # torch.device and then use it for all the tensors.
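A side note on why the guard compares `packaging.version.Version` objects rather than raw strings: string comparison is lexicographic and misorders multi-digit components. A minimal sketch:

```python
from packaging.version import Version

print("2.10.0" < "2.9.0")                    # True  -- lexicographic, wrong
print(Version("2.10.0") < Version("2.9.0"))  # False -- semantic, correct
```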

deepxde/data/mf.py

Lines changed: 15 additions & 11 deletions

@@ -1,7 +1,8 @@
 import numpy as np

 from .data import Data
-from ..backend import tf
+from .. import backend as bkd
+from .. import config
 from ..utils import run_if_any_none, standardize


@@ -83,20 +84,20 @@ def __init__(
         standardize=False,
     ):
         if X_lo_train is not None:
-            self.X_lo_train = X_lo_train
-            self.X_hi_train = X_hi_train
-            self.y_lo_train = y_lo_train
-            self.y_hi_train = y_hi_train
-            self.X_hi_test = X_hi_test
-            self.y_hi_test = y_hi_test
+            self.X_lo_train = X_lo_train.astype(config.real(np))
+            self.X_hi_train = X_hi_train.astype(config.real(np))
+            self.y_lo_train = y_lo_train.astype(config.real(np))
+            self.y_hi_train = y_hi_train.astype(config.real(np))
+            self.X_hi_test = X_hi_test.astype(config.real(np))
+            self.y_hi_test = y_hi_test.astype(config.real(np))
         elif fname_lo_train is not None:
-            data = np.loadtxt(fname_lo_train)
+            data = np.loadtxt(fname_lo_train).astype(config.real(np))
             self.X_lo_train = data[:, col_x]
             self.y_lo_train = data[:, col_y]
-            data = np.loadtxt(fname_hi_train)
+            data = np.loadtxt(fname_hi_train).astype(config.real(np))
             self.X_hi_train = data[:, col_x]
             self.y_hi_train = data[:, col_y]
-            data = np.loadtxt(fname_hi_test)
+            data = np.loadtxt(fname_hi_test).astype(config.real(np))
             self.X_hi_test = data[:, col_x]
             self.y_hi_test = data[:, col_y]
         else:
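The `.astype(config.real(np))` casts exist because `np.loadtxt` and most user-supplied arrays default to float64, while the configured backend precision is often float32; feeding mixed dtypes to the backend raises errors or silently upcasts. A minimal sketch of the pattern, with a hypothetical file name and `np.float32` standing in for `config.real(np)`:

```python
import numpy as np

# np.loadtxt returns float64 by default; cast to the configured precision
# before the data reaches backend tensors ("train_lo.dat" is illustrative).
data = np.loadtxt("train_lo.dat").astype(np.float32)
X, y = data[:, :-1], data[:, -1:]
```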
@@ -116,7 +117,10 @@ def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
         return [loss_lo, loss_hi]

     def losses_test(self, targets, outputs, loss_fn, inputs, model, aux=None):
-        return [0, loss_fn(targets[1], outputs[1])]
+        return [
+            bkd.as_tensor(0, dtype=config.real(bkd.lib)),
+            loss_fn(targets[1], outputs[1]),
+        ]

     @run_if_any_none("X_train", "y_train")
     def train_next_batch(self, batch_size=None):
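Replacing the Python int `0` with `bkd.as_tensor(0, dtype=config.real(bkd.lib))` keeps every element of the returned loss list a backend tensor of the configured dtype, so downstream code can stack or sum the losses uniformly. A PyTorch-flavored sketch of the failure the change avoids:

```python
import torch

losses = [0, torch.tensor(0.5)]
# torch.stack(losses)  # TypeError: expected Tensor as element 0
losses = [torch.tensor(0.0), torch.tensor(0.5)]
print(torch.stack(losses).sum())  # tensor(0.5000) -- uniform tensors compose
```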

deepxde/data/pde_operator.py

Lines changed: 54 additions & 16 deletions

@@ -237,23 +237,59 @@ def __init__(
         self.train_next_batch()
         self.test()

-    def _losses(self, outputs, loss_fn, inputs, model, num_func):
+    def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
         bcs_start = np.cumsum([0] + self.pde.num_bcs)

         losses = []
-        for i in range(num_func):
-            out = outputs[i]
-            # Single output
-            if bkd.ndim(out) == 1:
-                out = out[:, None]
+        # PDE loss
+        if config.autodiff == "reverse":  # reverse mode AD
+            for i in range(num_func):
+                out = outputs[i]
+                # Single output
+                if bkd.ndim(out) == 1:
+                    out = out[:, None]
+                f = []
+                if self.pde.pde is not None:
+                    f = self.pde.pde(
+                        inputs[1], out, model.net.auxiliary_vars[i][:, None]
+                    )
+                    if not isinstance(f, (list, tuple)):
+                        f = [f]
+                error_f = [fi[bcs_start[-1] :] for fi in f]
+                losses_i = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
+                losses.append(losses_i)
+
+            losses = zip(*losses)
+            # Use stack instead of as_tensor to keep the gradients.
+            losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
+        elif config.autodiff == "forward":  # forward mode AD
+
+            def forward_call(trunk_input):
+                return aux[0]((inputs[0], trunk_input))
+
             f = []
             if self.pde.pde is not None:
-                f = self.pde.pde(inputs[1], out, model.net.auxiliary_vars[i][:, None])
+                # Each f has the shape (N1, N2)
+                f = self.pde.pde(
+                    inputs[1], (outputs, forward_call), model.net.auxiliary_vars
+                )
                 if not isinstance(f, (list, tuple)):
                     f = [f]
-            error_f = [fi[bcs_start[-1] :] for fi in f]
-            losses_i = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
-
+            # Each error has the shape (N1, ~N2)
+            error_f = [fi[:, bcs_start[-1] :] for fi in f]
+            for error in error_f:
+                error_i = []
+                for i in range(num_func):
+                    error_i.append(loss_fn(bkd.zeros_like(error[i]), error[i]))
+                losses.append(bkd.reduce_mean(bkd.stack(error_i, 0)))
+
+        # BC loss
+        losses_bc = []
+        for i in range(num_func):
+            losses_i = []
+            out = outputs[i]
+            if bkd.ndim(out) == 1:
+                out = out[:, None]
             for j, bc in enumerate(self.pde.bcs):
                 beg, end = bcs_start[j], bcs_start[j + 1]
                 # The same BC points are used for training and testing.
@@ -267,19 +303,21 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func):
                 )
                 losses_i.append(loss_fn(bkd.zeros_like(error), error))

-            losses.append(losses_i)
+            losses_bc.append(losses_i)

-        losses = zip(*losses)
-        # Use stack instead of as_tensor to keep the gradients.
-        losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
+        losses_bc = zip(*losses_bc)
+        losses_bc = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses_bc]
+        losses.extend(losses_bc)
         return losses

     def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
         num_func = self.num_func if self.batch_size is None else self.batch_size
-        return self._losses(outputs, loss_fn, inputs, model, num_func)
+        return self._losses(outputs, loss_fn, inputs, model, num_func, aux=aux)

     def losses_test(self, targets, outputs, loss_fn, inputs, model, aux=None):
-        return self._losses(outputs, loss_fn, inputs, model, len(self.test_x[0]))
+        return self._losses(
+            outputs, loss_fn, inputs, model, len(self.test_x[0]), aux=aux
+        )

     def train_next_batch(self, batch_size=None):
         if self.train_x is None:
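In the new forward branch, the PDE receives the precomputed outputs together with a callable (`forward_call`) so derivatives with respect to trunk inputs can be taken with forward-mode AD, i.e. a Jacobian-vector product that yields the value and a directional derivative in a single pass. A minimal PyTorch sketch of that mechanism, independent of DeepXDE's internals (the toy `net` is illustrative):

```python
import torch
from torch.func import jvp

def net(x):  # stand-in for forward_call: trunk input -> network output
    return torch.sin(x).sum(dim=-1, keepdim=True)

x = torch.randn(5, 1)
tangent = torch.ones_like(x)           # direction for d/dx
y, dy_dx = jvp(net, (x,), (tangent,))  # value and derivative in one pass
```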

deepxde/model.py

Lines changed: 5 additions & 1 deletion

@@ -518,7 +518,11 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
             list(self.net.parameters()) + self.external_trainable_variables
         )
         self.opt = optimizers.get(
-            trainable_variables, self.opt_name, learning_rate=lr, decay=decay
+            trainable_variables,
+            self.opt_name,
+            learning_rate=lr,
+            decay=decay,
+            weight_decay=self.net.regularizer,
         )

     def train_step(inputs, targets, auxiliary_vars):
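Threading `self.net.regularizer` into `optimizers.get` lets an L2 penalty be applied as optimizer weight decay rather than as an explicit loss term. A hedged sketch of the idea; `make_optimizer` and the `["l2", 1e-4]` spec format are illustrative, not DeepXDE's actual internals:

```python
import torch

def make_optimizer(params, learning_rate, weight_decay=None):
    # weight_decay is assumed to arrive as e.g. ["l2", 1e-4]; None disables it.
    wd = weight_decay[1] if weight_decay is not None else 0.0
    return torch.optim.Adam(params, lr=learning_rate, weight_decay=wd)
```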

deepxde/nn/paddle/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -4,11 +4,13 @@
     "DeepONet",
     "DeepONetCartesianProd",
     "FNN",
+    "MfNN",
     "MsFFN",
     "PFNN",
     "STMsFFN",
 ]

 from .deeponet import DeepONet, DeepONetCartesianProd
 from .fnn import FNN, PFNN
+from .mfnn import MfNN
 from .msffn import MsFFN, STMsFFN
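With the re-export in place, the multifidelity network can be imported at package level (assuming the `mfnn` module is present; it is not among the diffs shown above):

```python
from deepxde.nn.paddle import MfNN  # re-exported via __all__ and the new import
```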

deepxde/nn/paddle/fnn.py

Lines changed: 37 additions & 15 deletions

@@ -3,12 +3,20 @@
 from .nn import NN
 from .. import activations
 from .. import initializers
+from .. import regularizers


 class FNN(NN):
     """Fully-connected neural network."""

-    def __init__(self, layer_sizes, activation, kernel_initializer):
+    def __init__(
+        self,
+        layer_sizes,
+        activation,
+        kernel_initializer,
+        regularization=None,
+        dropout_rate=0,
+    ):
         super().__init__()
         if isinstance(activation, list):
             if not (len(layer_sizes) - 1) == len(activation):
@@ -20,6 +28,13 @@ def __init__(self, layer_sizes, activation, kernel_initializer):
         self.activation = activations.get(activation)
         initializer = initializers.get(kernel_initializer)
         initializer_zero = initializers.get("zeros")
+        self.regularizer = regularizers.get(regularization)
+        self.dropout_rate = dropout_rate
+        if dropout_rate > 0:
+            self.dropouts = [
+                paddle.nn.Dropout(p=dropout_rate)
+                for _ in range(1, len(layer_sizes) - 1)
+            ]

         self.linears = paddle.nn.LayerList()
         for i in range(1, len(layer_sizes)):
@@ -37,6 +52,8 @@ def forward(self, inputs):
                 if isinstance(self.activation, list)
                 else self.activation(linear(x))
             )
+            if self.dropout_rate > 0:
+                x = self.dropouts[j](x)
         x = self.linears[-1](x)
         if self._output_transform is not None:
             x = self._output_transform(inputs, x)
@@ -58,11 +75,14 @@ class PFNN(NN):
         kernel_initializer: Initializer for the kernel weights matrix.
     """

-    def __init__(self, layer_sizes, activation, kernel_initializer):
+    def __init__(
+        self, layer_sizes, activation, kernel_initializer, regularization=None
+    ):
         super().__init__()
         self.activation = activations.get(activation)
         initializer = initializers.get(kernel_initializer)
         initializer_zero = initializers.get("zeros")
+        self.regularizer = regularizers.get(regularization)

         if len(layer_sizes) <= 1:
             raise ValueError("must specify input and output sizes")
@@ -73,7 +93,6 @@ def __init__(self, layer_sizes, activation, kernel_initializer):

         n_output = layer_sizes[-1]

-
         def make_linear(n_input, n_output):
             linear = paddle.nn.Linear(n_input, n_output)
             initializer(linear.weight)
@@ -92,18 +111,22 @@ def make_linear(n_input, n_output):
                 if isinstance(prev_layer_size, (list, tuple)):
                     # e.g. [8, 8, 8] -> [16, 16, 16]
                     self.layers.append(
-                        paddle.nn.LayerList([
-                            make_linear(prev_layer_size[j], curr_layer_size[j])
-                            for j in range(n_output)
-                        ])
+                        paddle.nn.LayerList(
+                            [
+                                make_linear(prev_layer_size[j], curr_layer_size[j])
+                                for j in range(n_output)
+                            ]
+                        )
                     )
                 else:
                     # e.g. 64 -> [8, 8, 8]
                     self.layers.append(
-                        paddle.nn.LayerList([
-                            make_linear(prev_layer_size, curr_layer_size[j])
-                            for j in range(n_output)
-                        ])
+                        paddle.nn.LayerList(
+                            [
+                                make_linear(prev_layer_size, curr_layer_size[j])
+                                for j in range(n_output)
+                            ]
+                        )
                     )
             else:  # e.g. 64 -> 64
                 if not isinstance(prev_layer_size, int):
@@ -115,10 +138,9 @@ def make_linear(n_input, n_output):
         # output layers
         if isinstance(layer_sizes[-2], (list, tuple)):  # e.g. [3, 3, 3] -> 3
             self.layers.append(
-                paddle.nn.LayerList([
-                    make_linear(layer_sizes[-2][j], 1)
-                    for j in range(n_output)
-                ])
+                paddle.nn.LayerList(
+                    [make_linear(layer_sizes[-2][j], 1) for j in range(n_output)]
+                )
             )
         else:
             self.layers.append(make_linear(layer_sizes[-2], n_output))
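A hedged usage sketch of the extended `FNN` signature; the layer sizes, dropout rate, and regularizer spec are illustrative (`regularizers.get` is assumed to accept a spec like `["l2", 1e-4]`):

```python
from deepxde.nn.paddle import FNN

net = FNN(
    layer_sizes=[2, 64, 64, 1],
    activation="tanh",
    kernel_initializer="Glorot uniform",
    regularization=["l2", 1e-4],  # stored as net.regularizer for the optimizer
    dropout_rate=0.1,             # one paddle.nn.Dropout after each hidden activation
)
```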
