76 changes: 70 additions & 6 deletions deepxde/nn/pytorch/deeponet.py
@@ -58,6 +58,11 @@ class DeepONet(NN):
Split the trunk net and share the branch net. The width of the last layer
in the trunk net should be equal to the one in the branch net multiplied
by the number of outputs.
dropout_rate: If `dropout_rate` is a ``float`` between 0 and 1, the same
rate is used in both the trunk and branch nets. If `dropout_rate` is a
``dict``, the trunk net uses `dropout_rate["trunk"]` and the branch net
uses `dropout_rate["branch"]`. Each value should be a ``float`` or a
list of ``float`` with one rate per layer.
"""

def __init__(
@@ -69,6 +74,7 @@ def __init__(
num_outputs=1,
multi_output_strategy=None,
regularization=None,
dropout_rate=0,
):
super().__init__()
if isinstance(activation, dict):
@@ -79,6 +85,12 @@ def __init__(
self.kernel_initializer = kernel_initializer
self.regularizer = regularization

if isinstance(dropout_rate, dict):
self.dropout_rate_branch = dropout_rate["branch"]
self.dropout_rate_trunk = dropout_rate["trunk"]
else:
self.dropout_rate_branch = self.dropout_rate_trunk = dropout_rate

self.num_outputs = num_outputs
if self.num_outputs == 1:
if multi_output_strategy is not None:
@@ -115,10 +127,20 @@ def build_branch_net(self, layer_sizes_branch):
if callable(layer_sizes_branch[1]):
return layer_sizes_branch[1]
# Fully connected network
return FNN(layer_sizes_branch, self.activation_branch, self.kernel_initializer)
return FNN(
layer_sizes_branch,
self.activation_branch,
self.kernel_initializer,
dropout_rate=self.dropout_rate_branch,
)

def build_trunk_net(self, layer_sizes_trunk):
return FNN(layer_sizes_trunk, self.activation_trunk, self.kernel_initializer)
return FNN(
layer_sizes_trunk,
self.activation_trunk,
self.kernel_initializer,
dropout_rate=self.dropout_rate_trunk,
)

def merge_branch_trunk(self, x_func, x_loc, index):
y = torch.einsum("bi,bi->b", x_func, x_loc)
@@ -182,6 +204,11 @@ class DeepONetCartesianProd(NN):
Split the trunk net and share the branch net. The width of the last layer
in the trunk net should be equal to the one in the branch net multiplied
by the number of outputs.
dropout_rate: If `dropout_rate` is a ``float`` between 0 and 1, the same
rate is used in both the trunk and branch nets. If `dropout_rate` is a
``dict``, the trunk net uses `dropout_rate["trunk"]` and the branch net
uses `dropout_rate["branch"]`. Each value should be a ``float`` or a
list of ``float`` with one rate per layer.
"""

def __init__(
@@ -193,6 +220,7 @@ def __init__(
num_outputs=1,
multi_output_strategy=None,
regularization=None,
dropout_rate=0,
):
super().__init__()
if isinstance(activation, dict):
@@ -203,6 +231,12 @@ def __init__(
self.kernel_initializer = kernel_initializer
self.regularizer = regularization

if isinstance(dropout_rate, dict):
self.dropout_rate_branch = dropout_rate["branch"]
self.dropout_rate_trunk = dropout_rate["trunk"]
else:
self.dropout_rate_branch = self.dropout_rate_trunk = dropout_rate

self.num_outputs = num_outputs
if self.num_outputs == 1:
if multi_output_strategy is not None:
@@ -239,10 +273,20 @@ def build_branch_net(self, layer_sizes_branch):
if callable(layer_sizes_branch[1]):
return layer_sizes_branch[1]
# Fully connected network
return FNN(layer_sizes_branch, self.activation_branch, self.kernel_initializer)
return FNN(
layer_sizes_branch,
self.activation_branch,
self.kernel_initializer,
dropout_rate=self.dropout_rate_branch,
)

def build_trunk_net(self, layer_sizes_trunk):
return FNN(layer_sizes_trunk, self.activation_trunk, self.kernel_initializer)
return FNN(
layer_sizes_trunk,
self.activation_trunk,
self.kernel_initializer,
dropout_rate=self.dropout_rate_trunk,
)

def merge_branch_trunk(self, x_func, x_loc, index):
y = torch.einsum("bi,ni->bn", x_func, x_loc)
@@ -281,6 +325,11 @@ class PODDeepONet(NN):
`activation["branch"]`.
layer_sizes_trunk (list): A list of integers as the width of a fully connected
network. If ``None``, then only use POD basis as the trunk net.
dropout_rate: If `dropout_rate` is a ``float`` between 0 and 1, the same
rate is used in both the trunk and branch nets. If `dropout_rate` is a
``dict``, the trunk net uses `dropout_rate["trunk"]` and the branch net
uses `dropout_rate["branch"]`. Each value should be a ``float`` or a
list of ``float`` with one rate per layer.

References:
`L. Lu, X. Meng, S. Cai, Z. Mao, S. Goswami, Z. Zhang, & G. E. Karniadakis. A
@@ -297,6 +346,7 @@ def __init__(
kernel_initializer,
layer_sizes_trunk=None,
regularization=None,
dropout_rate=0,
):
super().__init__()
self.regularizer = regularization
@@ -307,17 +357,31 @@ def __init__(
else:
activation_branch = self.activation_trunk = activations.get(activation)

if isinstance(dropout_rate, dict):
dropout_rate_branch = dropout_rate["branch"]
dropout_rate_trunk = dropout_rate["trunk"]
else:
dropout_rate_branch = dropout_rate_trunk = dropout_rate

if callable(layer_sizes_branch[1]):
# User-defined network
self.branch = layer_sizes_branch[1]
else:
# Fully connected network
self.branch = FNN(layer_sizes_branch, activation_branch, kernel_initializer)
self.branch = FNN(
layer_sizes_branch,
activation_branch,
kernel_initializer,
dropout_rate=dropout_rate_branch,
)

self.trunk = None
if layer_sizes_trunk is not None:
self.trunk = FNN(
layer_sizes_trunk, self.activation_trunk, kernel_initializer
layer_sizes_trunk,
self.activation_trunk,
kernel_initializer,
dropout_rate=dropout_rate_trunk,
)
self.b = torch.nn.parameter.Parameter(torch.tensor(0.0))

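For context on the `dropout_rate` argument documented above, here is a minimal usage sketch. It assumes the PyTorch backend is selected; the layer sizes and rates are illustrative and not taken from this PR. A single ``float`` applies the same rate to both sub-networks, a ``dict`` sets separate rates, and a list inside the dict gives one rate per layer of that sub-network.

```python
# Minimal sketch of the new dropout_rate argument (assumes DDE_BACKEND=pytorch;
# layer sizes are illustrative).
import deepxde as dde

# A single float: the same rate is used in the branch and trunk nets.
net = dde.nn.DeepONetCartesianProd(
    [100, 128, 128],   # branch net: 100 input sensors, two layers of width 128
    [2, 128, 128],     # trunk net: 2 input coordinates, two layers of width 128
    "relu",
    "Glorot normal",
    dropout_rate=0.1,
)

# A dict: separate rates per sub-network; a list supplies one rate per layer.
net = dde.nn.DeepONetCartesianProd(
    [100, 128, 128],
    [2, 128, 128],
    "relu",
    "Glorot normal",
    dropout_rate={"branch": 0.2, "trunk": [0.0, 0.1]},
)
```

`DeepONet` and `PODDeepONet` accept the same argument and expand it the same way before passing it to the underlying `FNN`s.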
21 changes: 20 additions & 1 deletion deepxde/nn/pytorch/fnn.py
@@ -10,7 +10,12 @@ class FNN(NN):
"""Fully-connected neural network."""

def __init__(
self, layer_sizes, activation, kernel_initializer, regularization=None
self,
layer_sizes,
activation,
kernel_initializer,
regularization=None,
dropout_rate=0,
):
super().__init__()
if isinstance(activation, list):
@@ -21,6 +26,16 @@ def __init__(
self.activation = list(map(activations.get, activation))
else:
self.activation = activations.get(activation)

if isinstance(dropout_rate, list):
if len(dropout_rate) != len(layer_sizes) - 1:
raise ValueError(
f"Number of dropout rates must be equal to {len(layer_sizes) - 1}"
)
self.dropout_rate = dropout_rate
else:
self.dropout_rate = [dropout_rate] * (len(layer_sizes) - 1)

initializer = initializers.get(kernel_initializer)
initializer_zero = initializers.get("zeros")
self.regularizer = regularization
@@ -45,6 +60,10 @@ def forward(self, inputs):
if isinstance(self.activation, list)
else self.activation(linear(x))
)
if self.dropout_rate[j] > 0:
x = torch.nn.functional.dropout(
x, p=self.dropout_rate[j], training=self.training
)
x = self.linears[-1](x)
if self._output_transform is not None:
x = self._output_transform(inputs, x)
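In `FNN` itself, a list of rates must contain exactly `len(layer_sizes) - 1` entries (one per weight layer), otherwise the constructor raises a `ValueError`; dropout is applied after each hidden-layer activation and is only active while the module is in training mode. A minimal sketch, again assuming the PyTorch backend and illustrative layer sizes:

```python
import torch
import deepxde as dde  # assumes DDE_BACKEND=pytorch

# [2, 64, 64, 1] has three weight layers, so three rates are required.
# The last rate corresponds to the output layer; forward() only applies
# dropout after the hidden activations, so it is left at 0.0 here.
net = dde.nn.FNN([2, 64, 64, 1], "tanh", "Glorot normal", dropout_rate=[0.1, 0.1, 0.0])

net.train()                      # dropout active
y_train = net(torch.rand(5, 2))
net.eval()                       # dropout disabled, since training=self.training is passed
y_eval = net(torch.rand(5, 2))
```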