
Commit 37bb316

Pytorch backend: dropout rate for FNN/DeepONet

1 parent e341d76

2 files changed: +90, -7 lines

deepxde/nn/pytorch/deeponet.py

Lines changed: 70 additions & 6 deletions
@@ -58,6 +58,11 @@ class DeepONet(NN):
             Split the trunk net and share the branch net. The width of the last layer
             in the trunk net should be equal to the one in the branch net multiplied
             by the number of outputs.
+        dropout_rate: If `dropout_rate` is a ``float`` between 0 and 1, then the
+            same rate is used in both trunk and branch nets. If `dropout_rate`
+            is a ``dict``, then the trunk net uses the rate `dropout_rate["trunk"]`,
+            and the branch net uses `dropout_rate["branch"]`. Both `dropout_rate["trunk"]`
+            and `dropout_rate["branch"]` should be ``float`` or lists of ``float``.
     """

     def __init__(
@@ -69,6 +74,7 @@ def __init__(
         num_outputs=1,
         multi_output_strategy=None,
         regularization=None,
+        dropout_rate=0,
     ):
         super().__init__()
         if isinstance(activation, dict):
@@ -79,6 +85,12 @@ def __init__(
         self.kernel_initializer = kernel_initializer
         self.regularizer = regularization

+        if isinstance(dropout_rate, dict):
+            self.dropout_rate_branch = dropout_rate["branch"]
+            self.dropout_rate_trunk = dropout_rate["trunk"]
+        else:
+            self.dropout_rate_branch = self.dropout_rate_trunk = dropout_rate
+
         self.num_outputs = num_outputs
         if self.num_outputs == 1:
             if multi_output_strategy is not None:
@@ -115,10 +127,20 @@ def build_branch_net(self, layer_sizes_branch):
         if callable(layer_sizes_branch[1]):
             return layer_sizes_branch[1]
         # Fully connected network
-        return FNN(layer_sizes_branch, self.activation_branch, self.kernel_initializer)
+        return FNN(
+            layer_sizes_branch,
+            self.activation_branch,
+            self.kernel_initializer,
+            dropout_rate=self.dropout_rate_branch,
+        )

     def build_trunk_net(self, layer_sizes_trunk):
-        return FNN(layer_sizes_trunk, self.activation_trunk, self.kernel_initializer)
+        return FNN(
+            layer_sizes_trunk,
+            self.activation_trunk,
+            self.kernel_initializer,
+            dropout_rate=self.dropout_rate_trunk,
+        )

     def merge_branch_trunk(self, x_func, x_loc, index):
         y = torch.einsum("bi,bi->b", x_func, x_loc)
@@ -182,6 +204,11 @@ class DeepONetCartesianProd(NN):
             Split the trunk net and share the branch net. The width of the last layer
             in the trunk net should be equal to the one in the branch net multiplied
             by the number of outputs.
+        dropout_rate: If `dropout_rate` is a ``float`` between 0 and 1, then the
+            same rate is used in both trunk and branch nets. If `dropout_rate`
+            is a ``dict``, then the trunk net uses the rate `dropout_rate["trunk"]`,
+            and the branch net uses `dropout_rate["branch"]`. Both `dropout_rate["trunk"]`
+            and `dropout_rate["branch"]` should be ``float`` or lists of ``float``.
     """

     def __init__(
@@ -193,6 +220,7 @@ def __init__(
         num_outputs=1,
         multi_output_strategy=None,
         regularization=None,
+        dropout_rate=0,
     ):
         super().__init__()
         if isinstance(activation, dict):
@@ -203,6 +231,12 @@ def __init__(
         self.kernel_initializer = kernel_initializer
         self.regularizer = regularization

+        if isinstance(dropout_rate, dict):
+            self.dropout_rate_branch = dropout_rate["branch"]
+            self.dropout_rate_trunk = dropout_rate["trunk"]
+        else:
+            self.dropout_rate_branch = self.dropout_rate_trunk = dropout_rate
+
         self.num_outputs = num_outputs
         if self.num_outputs == 1:
             if multi_output_strategy is not None:
@@ -239,10 +273,20 @@ def build_branch_net(self, layer_sizes_branch):
         if callable(layer_sizes_branch[1]):
             return layer_sizes_branch[1]
         # Fully connected network
-        return FNN(layer_sizes_branch, self.activation_branch, self.kernel_initializer)
+        return FNN(
+            layer_sizes_branch,
+            self.activation_branch,
+            self.kernel_initializer,
+            dropout_rate=self.dropout_rate_branch,
+        )

     def build_trunk_net(self, layer_sizes_trunk):
-        return FNN(layer_sizes_trunk, self.activation_trunk, self.kernel_initializer)
+        return FNN(
+            layer_sizes_trunk,
+            self.activation_trunk,
+            self.kernel_initializer,
+            dropout_rate=self.dropout_rate_trunk,
+        )

     def merge_branch_trunk(self, x_func, x_loc, index):
         y = torch.einsum("bi,ni->bn", x_func, x_loc)
@@ -281,6 +325,11 @@ class PODDeepONet(NN):
             `activation["branch"]`.
         layer_sizes_trunk (list): A list of integers as the width of a fully connected
             network. If ``None``, then only use POD basis as the trunk net.
+        dropout_rate: If `dropout_rate` is a ``float`` between 0 and 1, then the
+            same rate is used in both trunk and branch nets. If `dropout_rate`
+            is a ``dict``, then the trunk net uses the rate `dropout_rate["trunk"]`,
+            and the branch net uses `dropout_rate["branch"]`. Both `dropout_rate["trunk"]`
+            and `dropout_rate["branch"]` should be ``float`` or lists of ``float``.

     References:
         `L. Lu, X. Meng, S. Cai, Z. Mao, S. Goswami, Z. Zhang, & G. E. Karniadakis. A
@@ -297,6 +346,7 @@ def __init__(
         kernel_initializer,
         layer_sizes_trunk=None,
         regularization=None,
+        dropout_rate=0,
     ):
         super().__init__()
         self.regularizer = regularization
@@ -307,17 +357,31 @@ def __init__(
         else:
             activation_branch = self.activation_trunk = activations.get(activation)

+        if isinstance(dropout_rate, dict):
+            dropout_rate_branch = dropout_rate["branch"]
+            dropout_rate_trunk = dropout_rate["trunk"]
+        else:
+            dropout_rate_branch = dropout_rate_trunk = dropout_rate
+
         if callable(layer_sizes_branch[1]):
             # User-defined network
             self.branch = layer_sizes_branch[1]
         else:
             # Fully connected network
-            self.branch = FNN(layer_sizes_branch, activation_branch, kernel_initializer)
+            self.branch = FNN(
+                layer_sizes_branch,
+                activation_branch,
+                kernel_initializer,
+                dropout_rate=dropout_rate_branch,
+            )

         self.trunk = None
         if layer_sizes_trunk is not None:
             self.trunk = FNN(
-                layer_sizes_trunk, self.activation_trunk, kernel_initializer
+                layer_sizes_trunk,
+                self.activation_trunk,
+                kernel_initializer,
+                dropout_rate=dropout_rate_trunk,
             )
         self.b = torch.nn.parameter.Parameter(torch.tensor(0.0))

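Taken together, the deeponet.py changes thread a single `dropout_rate` argument from each DeepONet variant down to its branch and trunk FNNs. A minimal usage sketch, not part of the commit: the positional arguments follow the usual DeepONet constructor, and the layer sizes, activation, and initializer below are illustrative.

from deepxde.nn.pytorch.deeponet import DeepONet

# One rate shared by every hidden layer of both subnetworks.
net = DeepONet(
    [50, 128, 128],  # branch layer sizes (illustrative)
    [1, 128, 128],   # trunk layer sizes (illustrative)
    "relu",
    "Glorot normal",
    dropout_rate=0.1,
)

# Per-subnetwork rates via a dict; each value may be a float or a list
# with len(layer_sizes) - 1 entries for the corresponding FNN.
net = DeepONet(
    [50, 128, 128],
    [1, 128, 128],
    "relu",
    "Glorot normal",
    dropout_rate={"branch": 0.2, "trunk": [0.1, 0.0]},
)
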
deepxde/nn/pytorch/fnn.py

Lines changed: 20 additions & 1 deletion
@@ -10,7 +10,12 @@ class FNN(NN):
     """Fully-connected neural network."""

     def __init__(
-        self, layer_sizes, activation, kernel_initializer, regularization=None
+        self,
+        layer_sizes,
+        activation,
+        kernel_initializer,
+        regularization=None,
+        dropout_rate=0,
     ):
         super().__init__()
         if isinstance(activation, list):
@@ -21,6 +26,16 @@ def __init__(
             self.activation = list(map(activations.get, activation))
         else:
             self.activation = activations.get(activation)
+
+        if isinstance(dropout_rate, list):
+            if not (len(layer_sizes) - 1) == len(dropout_rate):
+                raise ValueError(
+                    f"Number of dropout rates must be equal to {len(layer_sizes) - 1}"
+                )
+            self.dropout_rate = dropout_rate
+        else:
+            self.dropout_rate = [dropout_rate] * (len(layer_sizes) - 1)
+
         initializer = initializers.get(kernel_initializer)
         initializer_zero = initializers.get("zeros")
         self.regularizer = regularization
@@ -45,6 +60,10 @@ def forward(self, inputs):
                 if isinstance(self.activation, list)
                 else self.activation(linear(x))
             )
+            if self.dropout_rate[j] > 0:
+                x = torch.nn.functional.dropout(
+                    x, p=self.dropout_rate[j], training=self.training
+                )
         x = self.linears[-1](x)
         if self._output_transform is not None:
             x = self._output_transform(inputs, x)
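
The fnn.py hunks are where dropout actually happens: `torch.nn.functional.dropout` is applied after each hidden activation with `training=self.training`, so it follows the module's train/eval state. Note that the validated list has `len(layer_sizes) - 1` entries while `forward` only reads the indices of the hidden layers, so the last entry is effectively unused. A small sketch of the new argument (illustrative sizes, not part of the commit):

import torch
from deepxde.nn.pytorch.fnn import FNN

# layer_sizes has 4 entries, so a dropout_rate list must have 3 rates;
# rate j is applied after hidden layer j.
net = FNN([2, 64, 64, 1], "tanh", "Glorot normal", dropout_rate=[0.1, 0.1, 0.0])

x = torch.rand(8, 2)
net.train()   # dropout active: repeated forward passes differ
y_train = net(x)
net.eval()    # dropout disabled via training=self.training
y_eval = net(x)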
