Skip to content

Commit 6516c27

Browse files
Backend Paddle: Add DeepONet (#935)
1 parent c869921 commit 6516c27

File tree

3 files changed

+82
-2
lines changed

3 files changed

+82
-2
lines changed

deepxde/nn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
# The backend should be tensorflow/tensorflow.compat.v1 to ensure backend.tf is not
2525
# None.
2626
from . import jax
27-
from . import paddle
27+
from . import paddle
2828
from . import pytorch
2929
from . import tensorflow
3030
from . import tensorflow_compat_v1

deepxde/nn/paddle/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Package for paddle NN modules."""
22

3-
__all__ = ["DeepONetCartesianProd", "FNN"]
3+
__all__ = ["DeepONet", "DeepONetCartesianProd", "FNN"]
44

5+
from .deeponet import DeepONet
56
from .deeponet import DeepONetCartesianProd
67
from .fnn import FNN

deepxde/nn/paddle/deeponet.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,85 @@
66
from .. import initializers
77

88

9+
class DeepONet(NN):
10+
"""Deep operator network.
11+
12+
`Lu et al. Learning nonlinear operators via DeepONet based on the universal
13+
approximation theorem of operators. Nat Mach Intell, 2021.
14+
<https://doi.org/10.1038/s42256-021-00302-5>`_
15+
16+
Args:
17+
layer_sizes_branch: A list of integers as the width of a fully connected
18+
network, or `(dim, f)` where `dim` is the input dimension and `f` is a
19+
network function. The width of the last layer in the branch and trunk net
20+
should be equal.
21+
layer_sizes_trunk (list): A list of integers as the width of a fully connected
22+
network.
23+
activation: If `activation` is a ``string``, then the same activation is used in
24+
both trunk and branch nets. If `activation` is a ``dict``, then the trunk
25+
net uses the activation `activation["trunk"]`, and the branch net uses
26+
`activation["branch"]`.
27+
"""
28+
29+
def __init__(
30+
self,
31+
layer_sizes_branch,
32+
layer_sizes_trunk,
33+
activation,
34+
kernel_initializer,
35+
use_bias=True,
36+
):
37+
super().__init__()
38+
self.layer_sizes_func = layer_sizes_branch
39+
self.layer_sizes_loc = layer_sizes_trunk
40+
41+
if isinstance(activation, dict):
42+
self.activation_branch = activations.get(activation["branch"])
43+
self.activation_trunk = activations.get(activation["trunk"])
44+
else:
45+
activation_branch = self.activation_trunk = activations.get(activation)
46+
47+
self.kernel_initializer = initializers.get(kernel_initializer)
48+
49+
if callable(layer_sizes_branch[1]):
50+
# User-defined network
51+
self.branch = layer_sizes_branch[1]
52+
else:
53+
# Fully connected network
54+
self.branch = FNN(layer_sizes_branch, activation_branch, kernel_initializer)
55+
self.trunk = FNN(layer_sizes_trunk, self.activation_trunk, kernel_initializer)
56+
self.use_bias = use_bias
57+
if use_bias:
58+
# register bias to parameter for updating in optimizer and storage
59+
self.b = self.create_parameter(
60+
shape=(1, ),
61+
default_initializer=initializers.get("zeros")
62+
)
63+
64+
def forward(self, inputs):
65+
x_func = inputs[0]
66+
x_loc = inputs[1]
67+
# Branch net to encode the input function
68+
x_func = self.branch(x_func)
69+
# Trunk net to encode the domain of the output function
70+
if self._input_transform is not None:
71+
x_loc = self._input_transform(x_loc)
72+
x_loc = self.activation_trunk(self.trunk(x_loc))
73+
# Dot product
74+
if x_func.shape[-1] != x_loc.shape[-1]:
75+
raise AssertionError(
76+
"Output sizes of branch net and trunk net do not match."
77+
)
78+
x = paddle.einsum("bi,bi->b", x_func, x_loc) # [batch_size, ]
79+
x = paddle.reshape(x, [-1, 1]) # reshape [batch_size, ] to [batch_size, 1]
80+
# Add bias
81+
if self.use_bias:
82+
x += self.b
83+
if self._output_transform is not None:
84+
x = self._output_transform(inputs, x)
85+
return x
86+
87+
988
class DeepONetCartesianProd(NN):
1089
"""Deep operator network for dataset in the format of Cartesian product.
1190

0 commit comments

Comments
 (0)