Skip to content

Commit 0a97e24

Browse files
committed
update code 1
1 parent 8491cec commit 0a97e24

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

deepxde/nn/paddle/mfnn.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def __init__(
2626
self.layer_size_hi = layer_sizes_high_fidelity
2727

2828
self.activation = activations.get(activation)
29-
self.activation_tanh = activations.get("tanh")
3029
self.initializer = initializers.get(kernel_initializer)
3130
self.initializer_zero = initializers.get("zeros")
3231
self.trainable_lo = trainable_low_fidelity
@@ -35,7 +34,7 @@ def __init__(
3534
self.regularizer = regularizers.get(regularization)
3635

3736
# low fidelity
38-
self.linears_lo = self.init_dense(self.layer_size_lo, self.trainable_lo)
37+
self.linears_lo = self._init_dense(self.layer_size_lo, self.trainable_lo)
3938

4039
# high fidelity
4140
# linear part
@@ -52,18 +51,18 @@ def __init__(
5251
self.layer_size_hi = [
5352
self.layer_size_lo[0] + self.layer_size_lo[-1]
5453
] + self.layer_size_hi
55-
self.linears_hi = self.init_dense(self.layer_size_hi, self.trainable_hi)
54+
self.linears_hi = self._init_dense(self.layer_size_hi, self.trainable_hi)
5655
# linear + nonlinear
5756
if not self.residue:
58-
alpha = self.init_alpha(0.0, self.trainable_hi)
57+
alpha = self._init_alpha(0.0, self.trainable_hi)
5958
self.add_parameter("alpha", alpha)
6059
else:
61-
alpha1 = self.init_alpha(0.0, self.trainable_hi)
62-
alpha2 = self.init_alpha(0.0, self.trainable_hi)
60+
alpha1 = self._init_alpha(0.0, self.trainable_hi)
61+
alpha2 = self._init_alpha(0.0, self.trainable_hi)
6362
self.add_parameter("alpha1", alpha1)
6463
self.add_parameter("alpha2", alpha2)
6564

66-
def init_dense(self, layer_size, trainable):
65+
def _init_dense(self, layer_size, trainable):
6766
linears = paddle.nn.LayerList()
6867
for i in range(len(layer_size) - 1):
6968
linear = paddle.nn.Linear(
@@ -78,7 +77,7 @@ def init_dense(self, layer_size, trainable):
7877
linears.append(linear)
7978
return linears
8079

81-
def init_alpha(self, value, trainable):
80+
def _init_alpha(self, value, trainable):
8281
alpha = paddle.create_parameter(
8382
shape=[1],
8483
dtype=config.real(paddle),
@@ -88,7 +87,8 @@ def init_alpha(self, value, trainable):
8887
return alpha
8988

9089
def forward(self, inputs):
91-
x = inputs.astype(config.real(paddle))
90+
x = inputs
91+
9292
# low fidelity
9393
y = x
9494
for i, linear in enumerate(self.linears_lo):
@@ -107,14 +107,12 @@ def forward(self, inputs):
107107
y = linear(y)
108108
if i != len(self.linears_hi) - 1:
109109
y = self.activation(y)
110-
y_hi_nl = y
111-
# linear + nonlinear
112110
if not self.residue:
113-
alpha = self.activation_tanh(self.alpha)
111+
alpha = paddle.tanh(self.alpha)
114112
y_hi = y_hi_l + alpha * y_hi_nl
115113
else:
116-
alpha1 = self.activation_tanh(self.alpha1)
117-
alpha2 = self.activation_tanh(self.alpha2)
114+
alpha1 = paddle.tanh(self.alpha1)
115+
alpha2 = paddle.tanh(self.alpha2)
118116
y_hi = y_lo + 0.1 * (alpha1 * y_hi_l + alpha2 * y_hi_nl)
119117

120118
return y_lo, y_hi

0 commit comments

Comments
 (0)