1515
1616class OpNN (Map ):
1717 """Operator neural networks.
18+
19+ Args:
20+ activation: If `activation` is a ``string``, then the same activation is used in both trunk and branch nets.
21+ If `activation` is a ``dict``, then the trunk net uses the activation `activation["trunk"]`,
22+ and the branch net uses `activation["branch"]`.
23+ trainable_branch (bool)
24+ trainable_trunk: Boolean or a list of booleans.
1825 """
1926
2027 def __init__ (
@@ -38,7 +45,11 @@ def __init__(
3845
3946 self .layer_size_func = layer_size_branch
4047 self .layer_size_loc = layer_size_trunk
41- self .activation = activations .get (activation )
48+ if isinstance (activation , dict ):
49+ self .activation_branch = activations .get (activation ["branch" ])
50+ self .activation_trunk = activations .get (activation ["trunk" ])
51+ else :
52+ self .activation_branch = self .activation_trunk = activations .get (activation )
4253 self .kernel_initializer = initializers .get (kernel_initializer )
4354 if stacked :
4455 self .kernel_initializer_stacked = initializers .get (
@@ -95,7 +106,7 @@ def build(self):
95106 y_func ,
96107 self .layer_size_func [i ],
97108 stack_size ,
98- activation = self .activation ,
109+ activation = self .activation_branch ,
99110 trainable = self .trainable_branch ,
100111 )
101112 y_func = self .stacked_dense (
@@ -111,7 +122,7 @@ def build(self):
111122 y_func = self .dense (
112123 y_func ,
113124 self .layer_size_func [i ],
114- activation = self .activation ,
125+ activation = self .activation_branch ,
115126 regularizer = self .regularizer ,
116127 trainable = self .trainable_branch ,
117128 )
@@ -128,7 +139,7 @@ def build(self):
128139 y_loc = self .dense (
129140 y_loc ,
130141 self .layer_size_loc [i ],
131- activation = self .activation ,
142+ activation = self .activation_trunk ,
132143 regularizer = self .regularizer ,
133144 trainable = self .trainable_trunk [i - 1 ]
134145 if isinstance (self .trainable_trunk , (list , tuple ))
0 commit comments