Skip to content

Commit 3b97ba9

Browse files
committed
Tensorflow 1.x backend: branch subnet refactoring for DeepONet
1 parent f7aa563 commit 3b97ba9

File tree

1 file changed

+35
-41
lines changed

1 file changed

+35
-41
lines changed

deepxde/nn/tensorflow_compat_v1/deeponet.py

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -321,54 +321,48 @@ def build_branch_net(self):
321321
y_func = self.X_func
322322
if callable(self.layer_size_func[1]):
323323
# User-defined network
324-
y_func = self.layer_size_func[1](y_func)
325-
elif self.stacked:
326-
# Stacked fully connected network
327-
stack_size = self.layer_size_func[-1]
328-
for i in range(1, len(self.layer_size_func) - 1):
329-
y_func = self._stacked_dense(
330-
y_func,
331-
self.layer_size_func[i],
332-
stack_size,
333-
activation=self.activation_branch,
324+
return self.layer_size_func[1](y_func)
325+
326+
def _add_branch_layer(
327+
inputs, units, stack_size=None, activation=None, use_bias=True
328+
):
329+
if stack_size is None:
330+
return self._dense(
331+
inputs,
332+
units,
333+
activation=activation,
334+
regularizer=self.regularizer,
334335
trainable=self.trainable_branch,
336+
use_bias=use_bias,
335337
)
336-
if self.dropout_rate_branch[i - 1] > 0:
337-
y_func = tf.layers.dropout(
338-
y_func,
339-
rate=self.dropout_rate_branch[i - 1],
340-
training=self.training,
341-
)
342-
y_func = self._stacked_dense(
343-
y_func,
344-
1,
338+
return self._stacked_dense(
339+
inputs,
340+
units,
345341
stack_size,
346-
use_bias=self.use_bias,
342+
activation=activation,
347343
trainable=self.trainable_branch,
344+
use_bias=use_bias,
348345
)
349-
else:
350-
# Unstacked fully connected network
351-
for i in range(1, len(self.layer_size_func) - 1):
352-
y_func = self._dense(
353-
y_func,
354-
self.layer_size_func[i],
355-
activation=self.activation_branch,
356-
regularizer=self.regularizer,
357-
trainable=self.trainable_branch,
358-
)
359-
if self.dropout_rate_branch[i - 1] > 0:
360-
y_func = tf.layers.dropout(
361-
y_func,
362-
rate=self.dropout_rate_branch[i - 1],
363-
training=self.training,
364-
)
365-
y_func = self._dense(
346+
347+
for i in range(1, len(self.layer_size_func) - 1):
348+
y_func = _add_branch_layer(
366349
y_func,
367-
self.layer_size_func[-1],
368-
use_bias=self.use_bias,
369-
regularizer=self.regularizer,
370-
trainable=self.trainable_branch,
350+
self.layer_size_func[i],
351+
self.layer_size_func[-1] if self.stacked else None,
352+
activation=self.activation_branch,
371353
)
354+
if self.dropout_rate_branch[i - 1] > 0:
355+
y_func = tf.layers.dropout(
356+
y_func,
357+
rate=self.dropout_rate_branch[i - 1],
358+
training=self.training,
359+
)
360+
y_func = _add_branch_layer(
361+
y_func,
362+
1 if self.stacked else self.layer_size_func[-1],
363+
self.layer_size_func[-1] if self.stacked else None,
364+
use_bias=self.use_bias,
365+
)
372366
return y_func
373367

374368
def build_trunk_net(self):

0 commit comments

Comments
 (0)