Skip to content

Commit f81ef5a

Browse files
committed
OpNN supports different activations for branch and trunk nets
1 parent 9374a9d commit f81ef5a

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

deepxde/maps/opnn.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@
1515

1616
class OpNN(Map):
1717
"""Operator neural networks.
18+
19+
Args:
20+
activation: If `activation` is a ``string``, then the same activation is used in both trunk and branch nets.
21+
If `activation` is a ``dict``, then the trunk net uses the activation `activation["trunk"]`,
22+
and the branch net uses `activation["branch"]`.
23+
trainable_branch (bool)
24+
trainable_trunk: Boolean or a list of booleans.
1825
"""
1926

2027
def __init__(
@@ -38,7 +45,11 @@ def __init__(
3845

3946
self.layer_size_func = layer_size_branch
4047
self.layer_size_loc = layer_size_trunk
41-
self.activation = activations.get(activation)
48+
if isinstance(activation, dict):
49+
self.activation_branch = activations.get(activation["branch"])
50+
self.activation_trunk = activations.get(activation["trunk"])
51+
else:
52+
self.activation_branch = self.activation_trunk = activations.get(activation)
4253
self.kernel_initializer = initializers.get(kernel_initializer)
4354
if stacked:
4455
self.kernel_initializer_stacked = initializers.get(
@@ -95,7 +106,7 @@ def build(self):
95106
y_func,
96107
self.layer_size_func[i],
97108
stack_size,
98-
activation=self.activation,
109+
activation=self.activation_branch,
99110
trainable=self.trainable_branch,
100111
)
101112
y_func = self.stacked_dense(
@@ -111,7 +122,7 @@ def build(self):
111122
y_func = self.dense(
112123
y_func,
113124
self.layer_size_func[i],
114-
activation=self.activation,
125+
activation=self.activation_branch,
115126
regularizer=self.regularizer,
116127
trainable=self.trainable_branch,
117128
)
@@ -128,7 +139,7 @@ def build(self):
128139
y_loc = self.dense(
129140
y_loc,
130141
self.layer_size_loc[i],
131-
activation=self.activation,
142+
activation=self.activation_trunk,
132143
regularizer=self.regularizer,
133144
trainable=self.trainable_trunk[i - 1]
134145
if isinstance(self.trainable_trunk, (list, tuple))

0 commit comments

Comments
 (0)