1414
1515
1616class OpNN (Map ):
17- """Operator neural networks .
17+ """Deep operator network .
1818
1919 Args:
20+ layer_size_branch: A list of integers as the width of a fully connected network, or `(dim, f)` where `dim` is
21+ the input dimension and `f` is a network function. The width of the last layer in the branch and trunk net
22+ should be equal.
23+ layer_size_trunk (list): A list of integers as the width of a fully connected network.
2024 activation: If `activation` is a ``string``, then the same activation is used in both trunk and branch nets.
2125 If `activation` is a ``dict``, then the trunk net uses the activation `activation["trunk"]`,
2226 and the branch net uses `activation["branch"]`.
23- trainable_branch (bool)
27+ trainable_branch: Boolean.
2428 trainable_trunk: Boolean or a list of booleans.
2529 """
2630
@@ -37,8 +41,6 @@ def __init__(
3741 trainable_trunk = True ,
3842 ):
3943 super (OpNN , self ).__init__ ()
40- if layer_size_branch [- 1 ] != layer_size_trunk [- 1 ]:
41- raise ValueError ("Output sizes of branch net and trunk net do not match." )
4244 if isinstance (trainable_trunk , (list , tuple )):
4345 if len (trainable_trunk ) != len (layer_size_trunk ) - 1 :
4446 raise ValueError ("trainable_trunk does not match layer_size_trunk." )
@@ -98,8 +100,11 @@ def build(self):
98100
99101 # Branch net to encode the input function
100102 y_func = self .X_func
101- if self .stacked :
102- # Stacked
103+ if callable (self .layer_size_func [1 ]):
104+ # User-defined network
105+ y_func = self .layer_size_func [1 ](y_func )
106+ elif self .stacked :
107+ # Stacked fully connected network
103108 stack_size = self .layer_size_func [- 1 ]
104109 for i in range (1 , len (self .layer_size_func ) - 1 ):
105110 y_func = self .stacked_dense (
@@ -117,7 +122,7 @@ def build(self):
117122 trainable = self .trainable_branch ,
118123 )
119124 else :
120- # Unstacked
125+ # Unstacked fully connected network
121126 for i in range (1 , len (self .layer_size_func ) - 1 ):
122127 y_func = self .dense (
123128 y_func ,
@@ -147,6 +152,10 @@ def build(self):
147152 )
148153
149154 # Dot product
155+ if y_func .get_shape ().as_list ()[- 1 ] != y_loc .get_shape ().as_list ()[- 1 ]:
156+ raise AssertionError (
157+ "Output sizes of branch net and trunk net do not match."
158+ )
150159 self .y = tf .einsum ("bi,bi->b" , y_func , y_loc )
151160 self .y = tf .expand_dims (self .y , axis = 1 )
152161 # Add bias
0 commit comments