@@ -58,6 +58,11 @@ class DeepONet(NN):
             Split the trunk net and share the branch net. The width of the last layer
             in the trunk net should be equal to the one in the branch net multiplied
             by the number of outputs.
+        dropout_rate: If `dropout_rate` is a ``float`` between 0 and 1, then the
+            same rate is used in both trunk and branch nets. If `dropout_rate`
+            is a ``dict``, then the trunk net uses the rate `dropout_rate["trunk"]`,
+            and the branch net uses `dropout_rate["branch"]`. Both `dropout_rate["trunk"]`
+            and `dropout_rate["branch"]` should be ``float`` or lists of ``float``.
     """
 
     def __init__(
@@ -69,6 +74,7 @@ def __init__(
         num_outputs=1,
         multi_output_strategy=None,
         regularization=None,
+        dropout_rate=0,
     ):
         super().__init__()
         if isinstance(activation, dict):
@@ -79,6 +85,12 @@ def __init__(
         self.kernel_initializer = kernel_initializer
         self.regularizer = regularization
 
+        if isinstance(dropout_rate, dict):
+            self.dropout_rate_branch = dropout_rate["branch"]
+            self.dropout_rate_trunk = dropout_rate["trunk"]
+        else:
+            self.dropout_rate_branch = self.dropout_rate_trunk = dropout_rate
+
         self.num_outputs = num_outputs
         if self.num_outputs == 1:
             if multi_output_strategy is not None:
@@ -115,10 +127,20 @@ def build_branch_net(self, layer_sizes_branch):
         if callable(layer_sizes_branch[1]):
             return layer_sizes_branch[1]
         # Fully connected network
-        return FNN(layer_sizes_branch, self.activation_branch, self.kernel_initializer)
+        return FNN(
+            layer_sizes_branch,
+            self.activation_branch,
+            self.kernel_initializer,
+            dropout_rate=self.dropout_rate_branch,
+        )
 
     def build_trunk_net(self, layer_sizes_trunk):
-        return FNN(layer_sizes_trunk, self.activation_trunk, self.kernel_initializer)
+        return FNN(
+            layer_sizes_trunk,
+            self.activation_trunk,
+            self.kernel_initializer,
+            dropout_rate=self.dropout_rate_trunk,
+        )
 
     def merge_branch_trunk(self, x_func, x_loc, index):
         y = torch.einsum("bi,bi->b", x_func, x_loc)
@@ -182,6 +204,11 @@ class DeepONetCartesianProd(NN):
             Split the trunk net and share the branch net. The width of the last layer
             in the trunk net should be equal to the one in the branch net multiplied
             by the number of outputs.
+        dropout_rate: If `dropout_rate` is a ``float`` between 0 and 1, then the
+            same rate is used in both trunk and branch nets. If `dropout_rate`
+            is a ``dict``, then the trunk net uses the rate `dropout_rate["trunk"]`,
+            and the branch net uses `dropout_rate["branch"]`. Both `dropout_rate["trunk"]`
+            and `dropout_rate["branch"]` should be ``float`` or lists of ``float``.
     """
 
     def __init__(
@@ -193,6 +220,7 @@ def __init__(
         num_outputs=1,
         multi_output_strategy=None,
         regularization=None,
+        dropout_rate=0,
     ):
         super().__init__()
         if isinstance(activation, dict):
@@ -203,6 +231,12 @@ def __init__(
         self.kernel_initializer = kernel_initializer
         self.regularizer = regularization
 
+        if isinstance(dropout_rate, dict):
+            self.dropout_rate_branch = dropout_rate["branch"]
+            self.dropout_rate_trunk = dropout_rate["trunk"]
+        else:
+            self.dropout_rate_branch = self.dropout_rate_trunk = dropout_rate
+
         self.num_outputs = num_outputs
         if self.num_outputs == 1:
             if multi_output_strategy is not None:
@@ -239,10 +273,20 @@ def build_branch_net(self, layer_sizes_branch):
         if callable(layer_sizes_branch[1]):
             return layer_sizes_branch[1]
         # Fully connected network
-        return FNN(layer_sizes_branch, self.activation_branch, self.kernel_initializer)
+        return FNN(
+            layer_sizes_branch,
+            self.activation_branch,
+            self.kernel_initializer,
+            dropout_rate=self.dropout_rate_branch,
+        )
 
     def build_trunk_net(self, layer_sizes_trunk):
-        return FNN(layer_sizes_trunk, self.activation_trunk, self.kernel_initializer)
+        return FNN(
+            layer_sizes_trunk,
+            self.activation_trunk,
+            self.kernel_initializer,
+            dropout_rate=self.dropout_rate_trunk,
+        )
 
     def merge_branch_trunk(self, x_func, x_loc, index):
         y = torch.einsum("bi,ni->bn", x_func, x_loc)
@@ -281,6 +325,11 @@ class PODDeepONet(NN):
             `activation["branch"]`.
         layer_sizes_trunk (list): A list of integers as the width of a fully connected
             network. If ``None``, then only use POD basis as the trunk net.
+        dropout_rate: If `dropout_rate` is a ``float`` between 0 and 1, then the
+            same rate is used in both trunk and branch nets. If `dropout_rate`
+            is a ``dict``, then the trunk net uses the rate `dropout_rate["trunk"]`,
+            and the branch net uses `dropout_rate["branch"]`. Both `dropout_rate["trunk"]`
+            and `dropout_rate["branch"]` should be ``float`` or lists of ``float``.
 
     References:
         `L. Lu, X. Meng, S. Cai, Z. Mao, S. Goswami, Z. Zhang, & G. E. Karniadakis. A
@@ -297,6 +346,7 @@ def __init__(
         kernel_initializer,
         layer_sizes_trunk=None,
         regularization=None,
+        dropout_rate=0,
     ):
         super().__init__()
         self.regularizer = regularization
@@ -307,17 +357,31 @@ def __init__(
         else:
             activation_branch = self.activation_trunk = activations.get(activation)
 
+        if isinstance(dropout_rate, dict):
+            dropout_rate_branch = dropout_rate["branch"]
+            dropout_rate_trunk = dropout_rate["trunk"]
+        else:
+            dropout_rate_branch = dropout_rate_trunk = dropout_rate
+
         if callable(layer_sizes_branch[1]):
             # User-defined network
             self.branch = layer_sizes_branch[1]
         else:
             # Fully connected network
-            self.branch = FNN(layer_sizes_branch, activation_branch, kernel_initializer)
+            self.branch = FNN(
+                layer_sizes_branch,
+                activation_branch,
+                kernel_initializer,
+                dropout_rate=dropout_rate_branch,
+            )
 
         self.trunk = None
         if layer_sizes_trunk is not None:
             self.trunk = FNN(
-                layer_sizes_trunk, self.activation_trunk, kernel_initializer
+                layer_sizes_trunk,
+                self.activation_trunk,
+                kernel_initializer,
+                dropout_rate=dropout_rate_trunk,
             )
         self.b = torch.nn.parameter.Parameter(torch.tensor(0.0))
 
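The new argument is purely additive: existing code that omits `dropout_rate` keeps the default of 0, i.e. no dropout. Below is a minimal usage sketch, assuming this patch is applied to the deepxde package; the layer widths, activation, initializer, and rates are illustrative choices, not values taken from the diff.

```python
import deepxde as dde

# Single float: the same dropout rate is applied in both subnetworks.
net = dde.nn.DeepONet(
    [100, 128, 128],  # branch net: 100 input sensors, hidden/output width 128
    [1, 128, 128],    # trunk net: 1D query locations; last widths must match
    "relu",
    "Glorot normal",
    dropout_rate=0.1,
)

# Dict form: separate rates per subnetwork. Per the docstring, each value may
# also be a list of floats, one rate per layer of that subnetwork.
net = dde.nn.DeepONet(
    [100, 128, 128],
    [1, 128, 128],
    "relu",
    "Glorot normal",
    dropout_rate={"branch": 0.1, "trunk": [0.1, 0.2]},
)
```

`DeepONetCartesianProd` and `PODDeepONet` take the same argument with the same float-or-dict semantics, as the diff shows.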