2 changes: 2 additions & 0 deletions deepxde/nn/paddle/__init__.py
@@ -4,11 +4,13 @@
"DeepONet",
"DeepONetCartesianProd",
"FNN",
"MfNN",
"MsFFN",
"PFNN",
"STMsFFN",
]

from .deeponet import DeepONet, DeepONetCartesianProd
from .fnn import FNN, PFNN
from .mfnn import MfNN
from .msffn import MsFFN, STMsFFN
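
With MfNN added to __all__ and imported above, the new network is exported alongside the other Paddle networks. A minimal import sketch; the dde.nn shortcut assumes the Paddle backend is selected (e.g. DDE_BACKEND=paddle), as for the other exported networks:

from deepxde.nn.paddle import MfNN

# or, once the Paddle backend is active:
# import deepxde as dde
# net_cls = dde.nn.MfNN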
43 changes: 29 additions & 14 deletions deepxde/nn/paddle/fnn.py
@@ -3,12 +3,20 @@
from .nn import NN
from .. import activations
from .. import initializers
from .. import regularizers


class FNN(NN):
"""Fully-connected neural network."""

def __init__(self, layer_sizes, activation, kernel_initializer):
def __init__(
self,
layer_sizes,
activation,
kernel_initializer,
regularization=None,
dropout_rate=0.0,
):
super().__init__()
if isinstance(activation, list):
if not (len(layer_sizes) - 1) == len(activation):
@@ -20,12 +28,15 @@ def __init__(self, layer_sizes, activation, kernel_initializer):
self.activation = activations.get(activation)
initializer = initializers.get(kernel_initializer)
initializer_zero = initializers.get("zeros")
self.regularizer = regularizers.get(regularization)
self.dropout_rate = dropout_rate

self.linears = paddle.nn.LayerList()
for i in range(1, len(layer_sizes)):
self.linears.append(paddle.nn.Linear(layer_sizes[i - 1], layer_sizes[i]))
initializer(self.linears[-1].weight)
initializer_zero(self.linears[-1].bias)
self.dropout = paddle.nn.Dropout(p=dropout_rate) if dropout_rate > 0.0 else None

def forward(self, inputs):
x = inputs
@@ -37,6 +48,8 @@ def forward(self, inputs):
if isinstance(self.activation, list)
else self.activation(linear(x))
)
if self.dropout is not None:
x = self.dropout(x)
x = self.linears[-1](x)
if self._output_transform is not None:
x = self._output_transform(inputs, x)
@@ -73,7 +86,6 @@ def __init__(self, layer_sizes, activation, kernel_initializer):

n_output = layer_sizes[-1]


def make_linear(n_input, n_output):
linear = paddle.nn.Linear(n_input, n_output)
initializer(linear.weight)
@@ -92,18 +104,22 @@ def make_linear(n_input, n_output):
if isinstance(prev_layer_size, (list, tuple)):
# e.g. [8, 8, 8] -> [16, 16, 16]
self.layers.append(
paddle.nn.LayerList([
make_linear(prev_layer_size[j], curr_layer_size[j])
for j in range(n_output)
])
paddle.nn.LayerList(
[
make_linear(prev_layer_size[j], curr_layer_size[j])
for j in range(n_output)
]
)
)
else:
# e.g. 64 -> [8, 8, 8]
self.layers.append(
paddle.nn.LayerList([
make_linear(prev_layer_size, curr_layer_size[j])
for j in range(n_output)
])
paddle.nn.LayerList(
[
make_linear(prev_layer_size, curr_layer_size[j])
for j in range(n_output)
]
)
)
else: # e.g. 64 -> 64
if not isinstance(prev_layer_size, int):
@@ -115,10 +131,9 @@ def make_linear(n_input, n_output):
# output layers
if isinstance(layer_sizes[-2], (list, tuple)): # e.g. [3, 3, 3] -> 3
self.layers.append(
paddle.nn.LayerList([
make_linear(layer_sizes[-2][j], 1)
for j in range(n_output)
])
paddle.nn.LayerList(
[make_linear(layer_sizes[-2][j], 1) for j in range(n_output)]
)
)
else:
self.layers.append(make_linear(layer_sizes[-2], n_output))
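
The two new keyword arguments default to the previous behavior (no regularization, no dropout), so existing FNN calls are unaffected. A minimal usage sketch, assuming the Paddle backend is active; the layer sizes, dropout rate, and the ["l2", 1e-5] regularization spec are illustrative values following the regularizers.get convention used elsewhere in DeepXDE:

from deepxde.nn.paddle import FNN

# 2 inputs -> two hidden layers of 50 units -> 1 output,
# with L2 weight regularization and 10% dropout after each hidden activation
net = FNN(
    [2, 50, 50, 1],
    "tanh",
    "Glorot uniform",
    regularization=["l2", 1e-5],
    dropout_rate=0.1,
)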
120 changes: 120 additions & 0 deletions deepxde/nn/paddle/mfnn.py
@@ -0,0 +1,120 @@
import paddle

from .nn import NN
from .. import activations
from .. import initializers
from .. import regularizers
from ... import config


class MfNN(NN):
"""Multifidelity neural networks."""

def __init__(
self,
layer_sizes_low_fidelity,
layer_sizes_high_fidelity,
activation,
kernel_initializer,
regularization=None,
residue=False,
trainable_low_fidelity=True,
trainable_high_fidelity=True,
):
super().__init__()
self.layer_size_lo = layer_sizes_low_fidelity
self.layer_size_hi = layer_sizes_high_fidelity

self.activation = activations.get(activation)
self.activation_tanh = activations.get("tanh")
self.initializer = initializers.get(kernel_initializer)
self.initializer_zero = initializers.get("zeros")
self.trainable_lo = trainable_low_fidelity
self.trainable_hi = trainable_high_fidelity
self.residue = residue
self.regularizer = regularizers.get(regularization)

# low fidelity
self.linears_lo = self.init_dense(self.layer_size_lo, self.trainable_lo)

# high fidelity
# linear part
self.linears_hi_l = paddle.nn.Linear(
in_features=self.layer_size_lo[0] + self.layer_size_lo[-1],
out_features=self.layer_size_hi[-1],
weight_attr=paddle.ParamAttr(initializer=self.initializer),
bias_attr=paddle.ParamAttr(initializer=self.initializer_zero),
)
if not self.trainable_hi:
for param in self.linears_hi_l.parameters():
param.stop_gradient = True  # True freezes these weights when the part is not trainable
# nonlinear part
self.layer_size_hi = [
self.layer_size_lo[0] + self.layer_size_lo[-1]
] + self.layer_size_hi
self.linears_hi = self.init_dense(self.layer_size_hi, self.trainable_hi)
# linear + nonlinear
if not self.residue:
alpha = self.init_alpha(0.0, self.trainable_hi)
self.add_parameter("alpha", alpha)
else:
alpha1 = self.init_alpha(0.0, self.trainable_hi)
alpha2 = self.init_alpha(0.0, self.trainable_hi)
self.add_parameter("alpha1", alpha1)
self.add_parameter("alpha2", alpha2)

def init_dense(self, layer_size, trainable):
linears = paddle.nn.LayerList()
for i in range(len(layer_size) - 1):
linear = paddle.nn.Linear(
in_features=layer_size[i],
out_features=layer_size[i + 1],
weight_attr=paddle.ParamAttr(initializer=self.initializer),
bias_attr=paddle.ParamAttr(initializer=self.initializer_zero),
)
if not trainable:
for param in linear.parameters():
param.stop_gradient = True  # True freezes these weights when trainable is False
linears.append(linear)
return linears

def init_alpha(self, value, trainable):
alpha = paddle.create_parameter(
shape=[1],
dtype=config.real(paddle),
default_initializer=paddle.nn.initializer.Constant(value),
)
alpha.stop_gradient = not trainable
return alpha

def forward(self, inputs):
x = inputs.astype(config.real(paddle))
# low fidelity
y = x
for i, linear in enumerate(self.linears_lo):
y = linear(y)
if i != len(self.linears_lo) - 1:
y = self.activation(y)
y_lo = y

# high fidelity
x_hi = paddle.concat([x, y_lo], axis=1)
# linear
y_hi_l = self.linears_hi_l(x_hi)
# nonlinear
y = x_hi
for i, linear in enumerate(self.linears_hi):
y = linear(y)
if i != len(self.linears_hi) - 1:
y = self.activation(y)
y_hi_nl = y
# linear + nonlinear
if not self.residue:
alpha = self.activation_tanh(self.alpha)
y_hi = y_hi_l + alpha * y_hi_nl
else:
alpha1 = self.activation_tanh(self.alpha1)
alpha2 = self.activation_tanh(self.alpha2)
y_hi = y_lo + 0.1 * (alpha1 * y_hi_l + alpha2 * y_hi_nl)

return y_lo, y_hi
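
A minimal construction sketch for the new network, assuming the Paddle backend is active; the layer sizes and batch size are illustrative. The first list is the full low-fidelity architecture (input to output), while the high-fidelity list omits the input size because its input is built inside forward by concatenating x with the low-fidelity output:

import paddle
from deepxde.nn.paddle import MfNN

net = MfNN(
    layer_sizes_low_fidelity=[1, 20, 20, 1],
    layer_sizes_high_fidelity=[10, 10, 1],
    activation="tanh",
    kernel_initializer="Glorot uniform",
)
x = paddle.rand([16, 1])  # batch of 16 one-dimensional inputs
y_lo, y_hi = net(x)       # low- and high-fidelity predictions, each of shape [16, 1]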