Skip to content

Commit 1748bc7

Browse files
authored
Backend TensorFlow: Add regularization to DeepONet (#1602)
1 parent d42a6ca commit 1748bc7

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

deepxde/nn/tensorflow/deeponet.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def __init__(
235235
kernel_initializer,
236236
num_outputs=1,
237237
multi_output_strategy=None,
238+
regularization=None,
238239
):
239240
super().__init__()
240241
if isinstance(activation, dict):
@@ -243,6 +244,7 @@ def __init__(
243244
else:
244245
self.activation_branch = self.activation_trunk = activations.get(activation)
245246
self.kernel_initializer = kernel_initializer
247+
self.regularization = regularization
246248

247249
self.num_outputs = num_outputs
248250
if self.num_outputs == 1:
@@ -280,10 +282,20 @@ def build_branch_net(self, layer_sizes_branch):
280282
if callable(layer_sizes_branch[1]):
281283
return layer_sizes_branch[1]
282284
# Fully connected network
283-
return FNN(layer_sizes_branch, self.activation_branch, self.kernel_initializer)
285+
return FNN(
286+
layer_sizes_branch,
287+
self.activation_branch,
288+
self.kernel_initializer,
289+
regularization=self.regularization,
290+
)
284291

285292
def build_trunk_net(self, layer_sizes_trunk):
286-
return FNN(layer_sizes_trunk, self.activation_trunk, self.kernel_initializer)
293+
return FNN(
294+
layer_sizes_trunk,
295+
self.activation_trunk,
296+
self.kernel_initializer,
297+
regularization=self.regularization,
298+
)
287299

288300
def merge_branch_trunk(self, x_func, x_loc):
289301
y = tf.einsum("bi,bi->b", x_func, x_loc)

0 commit comments

Comments
 (0)