Skip to content

Commit ba8e824

Browse files
authored
Backend Tensorflow 1.x: branch subnet refactoring for DeepONet (#1849)
1 parent f7aa563 commit ba8e824

File tree

1 file changed

+47
-38
lines changed

1 file changed

+47
-38
lines changed

deepxde/nn/tensorflow_compat_v1/deeponet.py

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -318,58 +318,67 @@ def build(self):
318318
self.built = True
319319

320320
def build_branch_net(self):
321-
y_func = self.X_func
322321
if callable(self.layer_size_func[1]):
323322
# User-defined network
324-
y_func = self.layer_size_func[1](y_func)
325-
elif self.stacked:
323+
return self.layer_size_func[1](self.X_func)
324+
325+
if self.stacked:
326326
# Stacked fully connected network
327-
stack_size = self.layer_size_func[-1]
328-
for i in range(1, len(self.layer_size_func) - 1):
329-
y_func = self._stacked_dense(
330-
y_func,
331-
self.layer_size_func[i],
332-
stack_size,
333-
activation=self.activation_branch,
334-
trainable=self.trainable_branch,
335-
)
336-
if self.dropout_rate_branch[i - 1] > 0:
337-
y_func = tf.layers.dropout(
338-
y_func,
339-
rate=self.dropout_rate_branch[i - 1],
340-
training=self.training,
341-
)
327+
return self._build_stacked_branch_net()
328+
329+
# Unstacked fully connected network
330+
return self._build_unstacked_branch_net()
331+
332+
def _build_stacked_branch_net(self):
333+
y_func = self.X_func
334+
stack_size = self.layer_size_func[-1]
335+
336+
for i in range(1, len(self.layer_size_func) - 1):
342337
y_func = self._stacked_dense(
343338
y_func,
344-
1,
345-
stack_size,
346-
use_bias=self.use_bias,
339+
self.layer_size_func[i],
340+
stack_size=stack_size,
341+
activation=self.activation_branch,
347342
trainable=self.trainable_branch,
348343
)
349-
else:
350-
# Unstacked fully connected network
351-
for i in range(1, len(self.layer_size_func) - 1):
352-
y_func = self._dense(
344+
if self.dropout_rate_branch[i - 1] > 0:
345+
y_func = tf.layers.dropout(
353346
y_func,
354-
self.layer_size_func[i],
355-
activation=self.activation_branch,
356-
regularizer=self.regularizer,
357-
trainable=self.trainable_branch,
347+
rate=self.dropout_rate_branch[i - 1],
348+
training=self.training,
358349
)
359-
if self.dropout_rate_branch[i - 1] > 0:
360-
y_func = tf.layers.dropout(
361-
y_func,
362-
rate=self.dropout_rate_branch[i - 1],
363-
training=self.training,
364-
)
350+
return self._stacked_dense(
351+
y_func,
352+
1,
353+
stack_size=stack_size,
354+
use_bias=self.use_bias,
355+
trainable=self.trainable_branch,
356+
)
357+
358+
def _build_unstacked_branch_net(self):
359+
y_func = self.X_func
360+
361+
for i in range(1, len(self.layer_size_func) - 1):
365362
y_func = self._dense(
366363
y_func,
367-
self.layer_size_func[-1],
368-
use_bias=self.use_bias,
364+
self.layer_size_func[i],
365+
activation=self.activation_branch,
369366
regularizer=self.regularizer,
370367
trainable=self.trainable_branch,
371368
)
372-
return y_func
369+
if self.dropout_rate_branch[i - 1] > 0:
370+
y_func = tf.layers.dropout(
371+
y_func,
372+
rate=self.dropout_rate_branch[i - 1],
373+
training=self.training,
374+
)
375+
return self._dense(
376+
y_func,
377+
self.layer_size_func[-1],
378+
use_bias=self.use_bias,
379+
regularizer=self.regularizer,
380+
trainable=self.trainable_branch,
381+
)
373382

374383
def build_trunk_net(self):
375384
y_loc = self.X_loc

0 commit comments

Comments
 (0)