Skip to content

Commit f516cd0

Browse files
committed
Opnn supports user-defined branch net
1 parent ab957bc commit f516cd0

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

deepxde/maps/opnn.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,17 @@
1414

1515

1616
class OpNN(Map):
17-
"""Operator neural networks.
17+
"""Deep operator network.
1818
1919
Args:
20+
layer_size_branch: A list of integers as the width of a fully connected network, or `(dim, f)` where `dim` is
21+
the input dimension and `f` is a network function. The width of the last layer in the branch and trunk net
22+
should be equal.
23+
layer_size_trunk (list): A list of integers as the width of a fully connected network.
2024
activation: If `activation` is a ``string``, then the same activation is used in both trunk and branch nets.
2125
If `activation` is a ``dict``, then the trunk net uses the activation `activation["trunk"]`,
2226
and the branch net uses `activation["branch"]`.
23-
trainable_branch (bool)
27+
trainable_branch: Boolean.
2428
trainable_trunk: Boolean or a list of booleans.
2529
"""
2630

@@ -37,8 +41,6 @@ def __init__(
3741
trainable_trunk=True,
3842
):
3943
super(OpNN, self).__init__()
40-
if layer_size_branch[-1] != layer_size_trunk[-1]:
41-
raise ValueError("Output sizes of branch net and trunk net do not match.")
4244
if isinstance(trainable_trunk, (list, tuple)):
4345
if len(trainable_trunk) != len(layer_size_trunk) - 1:
4446
raise ValueError("trainable_trunk does not match layer_size_trunk.")
@@ -98,8 +100,11 @@ def build(self):
98100

99101
# Branch net to encode the input function
100102
y_func = self.X_func
101-
if self.stacked:
102-
# Stacked
103+
if callable(self.layer_size_func[1]):
104+
# User-defined network
105+
y_func = self.layer_size_func[1](y_func)
106+
elif self.stacked:
107+
# Stacked fully connected network
103108
stack_size = self.layer_size_func[-1]
104109
for i in range(1, len(self.layer_size_func) - 1):
105110
y_func = self.stacked_dense(
@@ -117,7 +122,7 @@ def build(self):
117122
trainable=self.trainable_branch,
118123
)
119124
else:
120-
# Unstacked
125+
# Unstacked fully connected network
121126
for i in range(1, len(self.layer_size_func) - 1):
122127
y_func = self.dense(
123128
y_func,
@@ -147,6 +152,10 @@ def build(self):
147152
)
148153

149154
# Dot product
155+
if y_func.get_shape().as_list()[-1] != y_loc.get_shape().as_list()[-1]:
156+
raise AssertionError(
157+
"Output sizes of branch net and trunk net do not match."
158+
)
150159
self.y = tf.einsum("bi,bi->b", y_func, y_loc)
151160
self.y = tf.expand_dims(self.y, axis=1)
152161
# Add bias

0 commit comments

Comments
 (0)