Skip to content

Commit af61a6f

Browse files
committed
Backend PyTorch: Add regularizer to DeepONet
1 parent 67ec746 commit af61a6f

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

deepxde/nn/pytorch/deeponet.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(
6868
kernel_initializer,
6969
num_outputs=1,
7070
multi_output_strategy=None,
71+
regularization=None,
7172
):
7273
super().__init__()
7374
if isinstance(activation, dict):
@@ -76,6 +77,7 @@ def __init__(
7677
else:
7778
self.activation_branch = self.activation_trunk = activations.get(activation)
7879
self.kernel_initializer = kernel_initializer
80+
self.regularizer = regularization
7981

8082
self.num_outputs = num_outputs
8183
if self.num_outputs == 1:
@@ -190,6 +192,7 @@ def __init__(
190192
kernel_initializer,
191193
num_outputs=1,
192194
multi_output_strategy=None,
195+
regularization=None,
193196
):
194197
super().__init__()
195198
if isinstance(activation, dict):
@@ -198,6 +201,7 @@ def __init__(
198201
else:
199202
self.activation_branch = self.activation_trunk = activations.get(activation)
200203
self.kernel_initializer = kernel_initializer
204+
self.regularizer = regularization
201205

202206
self.num_outputs = num_outputs
203207
if self.num_outputs == 1:
@@ -295,7 +299,7 @@ def __init__(
295299
regularization=None,
296300
):
297301
super().__init__()
298-
self.regularization = regularization # TODO: currently unused
302+
self.regularizer = regularization
299303
self.pod_basis = torch.as_tensor(pod_basis, dtype=torch.float32)
300304
if isinstance(activation, dict):
301305
activation_branch = activation["branch"]

0 commit comments

Comments
 (0)