Skip to content

Commit 9729318

Browse files
authored
Backend TensorFlow 1: DeepONet customized branches (#807)
1 parent 8057c8e commit 9729318

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

deepxde/nn/tensorflow_compat_v1/mionet.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,21 @@ def build(self):
5858
self._inputs = [self.X_func1, self.X_func2, self.X_loc]
5959

6060
# Branch net 1
61-
y_func1 = self._net(
62-
self.X_func1, self.layer_branch1[1:], self.activation_branch1
63-
)
61+
if callable(self.layer_branch1[1]):
62+
# User-defined network
63+
y_func1 = self.layer_branch1[1](self.X_func1)
64+
else:
65+
y_func1 = self._net(
66+
self.X_func1, self.layer_branch1[1:], self.activation_branch1
67+
)
6468
# Branch net 2
65-
y_func2 = self._net(
66-
self.X_func2, self.layer_branch2[1:], self.activation_branch2
67-
)
69+
if callable(self.layer_branch2[1]):
70+
# User-defined network
71+
y_func2 = self.layer_branch2[1](self.X_func2)
72+
else:
73+
y_func2 = self._net(
74+
self.X_func2, self.layer_branch2[1:], self.activation_branch2
75+
)
6876
# Trunk net
6977
y_loc = self._net(self.X_loc, self.layer_trunk[1:], self.activation_trunk)
7078

@@ -103,13 +111,21 @@ def build(self):
103111
self._inputs = [self.X_func1, self.X_func2, self.X_loc]
104112

105113
# Branch net 1
106-
y_func1 = self._net(
107-
self.X_func1, self.layer_branch1[1:], self.activation_branch1
108-
)
114+
if callable(self.layer_branch1[1]):
115+
# User-defined network
116+
y_func1 = self.layer_branch1[1](self.X_func1)
117+
else:
118+
y_func1 = self._net(
119+
self.X_func1, self.layer_branch1[1:], self.activation_branch1
120+
)
109121
# Branch net 2
110-
y_func2 = self._net(
111-
self.X_func2, self.layer_branch2[1:], self.activation_branch2
112-
)
122+
if callable(self.layer_branch2[1]):
123+
# User-defined network
124+
y_func2 = self.layer_branch2[1](self.X_func2)
125+
else:
126+
y_func2 = self._net(
127+
self.X_func2, self.layer_branch2[1:], self.activation_branch2
128+
)
113129
# Trunk net
114130
y_loc = self._net(self.X_loc, self.layer_trunk[1:], self.activation_trunk)
115131

0 commit comments

Comments
 (0)