Commit fd3311b

Author: Kuangdai Leng

Multi-dimensional output strategy for DeepONet in PyTorch (#1643)
1 parent 0643941 commit fd3311b

File tree

3 files changed: +385 −246 lines changed

deepxde/nn/deeponet_strategy.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
from abc import ABC, abstractmethod


class DeepONetStrategy(ABC):
    """DeepONet building strategy.

    See Section 3.1.6 in
    L. Lu, X. Meng, S. Cai, Z. Mao, S. Goswami, Z. Zhang, & G. Karniadakis.
    A comprehensive and fair comparison of two neural operators
    (with practical extensions) based on FAIR data.
    Computer Methods in Applied Mechanics and Engineering, 393, 114778, 2022.
    """

    def __init__(self, net):
        self.net = net

    @abstractmethod
    def build(self, layer_sizes_branch, layer_sizes_trunk):
        """Build branch and trunk nets."""

    @abstractmethod
    def call(self, x_func, x_loc):
        """Forward pass."""


class SingleOutputStrategy(DeepONetStrategy):
    """Single output build strategy is the standard build method."""

    def build(self, layer_sizes_branch, layer_sizes_trunk):
        if layer_sizes_branch[-1] != layer_sizes_trunk[-1]:
            raise AssertionError(
                "Output sizes of branch net and trunk net do not match."
            )
        branch = self.net.build_branch_net(layer_sizes_branch)
        trunk = self.net.build_trunk_net(layer_sizes_trunk)
        return branch, trunk

    def call(self, x_func, x_loc):
        x_func = self.net.branch(x_func)
        x_loc = self.net.activation_trunk(self.net.trunk(x_loc))
        if x_func.shape[-1] != x_loc.shape[-1]:
            raise AssertionError(
                "Output sizes of branch net and trunk net do not match."
            )
        x = self.net.merge_branch_trunk(x_func, x_loc, 0)
        return x


class IndependentStrategy(DeepONetStrategy):
    """Directly use n independent DeepONets,
    and each DeepONet outputs only one function.
    """

    def build(self, layer_sizes_branch, layer_sizes_trunk):
        single_output_strategy = SingleOutputStrategy(self.net)
        branch, trunk = [], []
        for _ in range(self.net.num_outputs):
            branch_, trunk_ = single_output_strategy.build(
                layer_sizes_branch, layer_sizes_trunk
            )
            branch.append(branch_)
            trunk.append(trunk_)
        return branch, trunk

    def call(self, x_func, x_loc):
        xs = []
        for i in range(self.net.num_outputs):
            x_func_ = self.net.branch[i](x_func)
            x_loc_ = self.net.activation_trunk(self.net.trunk[i](x_loc))
            x = self.net.merge_branch_trunk(x_func_, x_loc_, i)
            xs.append(x)
        return self.net.concatenate_outputs(xs)


class SplitBothStrategy(DeepONetStrategy):
    """Split the outputs of both the branch net and the trunk net into n groups,
    and then the kth group outputs the kth solution.

    For example, if n = 2 and both the branch and trunk nets have 100 output neurons,
    then the dot product between the first 50 neurons of
    the branch and trunk nets generates the first function,
    and the remaining 50 neurons generate the second function.
    """

    def build(self, layer_sizes_branch, layer_sizes_trunk):
        if layer_sizes_branch[-1] != layer_sizes_trunk[-1]:
            raise AssertionError(
                "Output sizes of branch net and trunk net do not match."
            )
        if layer_sizes_branch[-1] % self.net.num_outputs != 0:
            raise AssertionError(
                f"Output size of the branch net is not evenly divisible by {self.net.num_outputs}."
            )
        single_output_strategy = SingleOutputStrategy(self.net)
        return single_output_strategy.build(layer_sizes_branch, layer_sizes_trunk)

    def call(self, x_func, x_loc):
        x_func = self.net.branch(x_func)
        x_loc = self.net.activation_trunk(self.net.trunk(x_loc))
        # Split x_func and x_loc into respective outputs
        shift = 0
        size = x_func.shape[1] // self.net.num_outputs
        xs = []
        for i in range(self.net.num_outputs):
            x_func_ = x_func[:, shift : shift + size]
            x_loc_ = x_loc[:, shift : shift + size]
            x = self.net.merge_branch_trunk(x_func_, x_loc_, i)
            xs.append(x)
            shift += size
        return self.net.concatenate_outputs(xs)


class SplitBranchStrategy(DeepONetStrategy):
    """Split the branch net and share the trunk net."""

    def build(self, layer_sizes_branch, layer_sizes_trunk):
        if layer_sizes_branch[-1] % self.net.num_outputs != 0:
            raise AssertionError(
                f"Output size of the branch net is not evenly divisible by {self.net.num_outputs}."
            )
        if layer_sizes_branch[-1] / self.net.num_outputs != layer_sizes_trunk[-1]:
            raise AssertionError(
                f"Output size of the trunk net does not equal {layer_sizes_branch[-1] // self.net.num_outputs}."
            )
        return self.net.build_branch_net(layer_sizes_branch), self.net.build_trunk_net(
            layer_sizes_trunk
        )

    def call(self, x_func, x_loc):
        x_func = self.net.branch(x_func)
        x_loc = self.net.activation_trunk(self.net.trunk(x_loc))
        # Split x_func into respective outputs
        shift = 0
        size = x_loc.shape[1]
        xs = []
        for i in range(self.net.num_outputs):
            x_func_ = x_func[:, shift : shift + size]
            x = self.net.merge_branch_trunk(x_func_, x_loc, i)
            xs.append(x)
            shift += size
        return self.net.concatenate_outputs(xs)


class SplitTrunkStrategy(DeepONetStrategy):
    """Split the trunk net and share the branch net."""

    def build(self, layer_sizes_branch, layer_sizes_trunk):
        if layer_sizes_trunk[-1] % self.net.num_outputs != 0:
            raise AssertionError(
                f"Output size of the trunk net is not evenly divisible by {self.net.num_outputs}."
            )
        if layer_sizes_trunk[-1] / self.net.num_outputs != layer_sizes_branch[-1]:
            raise AssertionError(
                f"Output size of the branch net does not equal {layer_sizes_trunk[-1] // self.net.num_outputs}."
            )
        return self.net.build_branch_net(layer_sizes_branch), self.net.build_trunk_net(
            layer_sizes_trunk
        )

    def call(self, x_func, x_loc):
        x_func = self.net.branch(x_func)
        x_loc = self.net.activation_trunk(self.net.trunk(x_loc))
        # Split x_loc into respective outputs
        shift = 0
        size = x_func.shape[1]
        xs = []
        for i in range(self.net.num_outputs):
            x_loc_ = x_loc[:, shift : shift + size]
            x = self.net.merge_branch_trunk(x_func, x_loc_, i)
            xs.append(x)
            shift += size
        return self.net.concatenate_outputs(xs)
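
The strategies above interact with the owning network only through a small set of hooks: num_outputs, build_branch_net, build_trunk_net, branch, trunk, activation_trunk, merge_branch_trunk, and concatenate_outputs. The PyTorch sketch below mocks those hooks with plausible but assumed implementations (the MLP builder, dot-product merge, and last-axis concatenation are illustrations, not code from this commit) to show how SplitBothStrategy turns shared branch and trunk features into two output functions. Running it standalone assumes DeepXDE's PyTorch backend is active (e.g., DDE_BACKEND=pytorch).

import torch

# Module added by this commit.
from deepxde.nn.deeponet_strategy import SplitBothStrategy


def _mlp(layer_sizes):
    # Assumed branch/trunk builder: a plain fully connected network.
    layers = []
    for i in range(len(layer_sizes) - 1):
        layers.append(torch.nn.Linear(layer_sizes[i], layer_sizes[i + 1]))
        if i < len(layer_sizes) - 2:
            layers.append(torch.nn.Tanh())
    return torch.nn.Sequential(*layers)


class ToyNet(torch.nn.Module):
    """Minimal stand-in for the network that owns a strategy.

    Only the hooks referenced by the strategies are provided; their bodies are
    assumptions for illustration, not code from this commit.
    """

    def __init__(self, num_outputs, layer_sizes_branch, layer_sizes_trunk):
        super().__init__()
        self.num_outputs = num_outputs
        self.activation_trunk = torch.tanh
        self.strategy = SplitBothStrategy(self)
        self.branch, self.trunk = self.strategy.build(
            layer_sizes_branch, layer_sizes_trunk
        )

    def build_branch_net(self, layer_sizes_branch):
        return _mlp(layer_sizes_branch)

    def build_trunk_net(self, layer_sizes_trunk):
        return _mlp(layer_sizes_trunk)

    def merge_branch_trunk(self, x_func, x_loc, index):
        # Assumed merge: per-sample dot product over the shared feature slice.
        # `index` would select a per-output bias in the real network; unused here.
        return torch.sum(x_func * x_loc, dim=1, keepdim=True)

    @staticmethod
    def concatenate_outputs(xs):
        # Assumed concatenation: one column per output function.
        return torch.cat(xs, dim=1)

    def forward(self, x_func, x_loc):
        return self.strategy.call(x_func, x_loc)


net = ToyNet(num_outputs=2, layer_sizes_branch=[40, 64, 100], layer_sizes_trunk=[1, 64, 100])
y = net(torch.rand(8, 40), torch.rand(8, 1))
print(y.shape)  # torch.Size([8, 2]); features 0-49 feed output 1, features 50-99 feed output 2

The other strategies plug into exactly the same hooks: IndependentStrategy duplicates the branch and trunk subnetworks per output, while SplitBranchStrategy and SplitTrunkStrategy slice only one side and reuse the other unchanged.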

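Only this new file is shown above; the other two files in the commit presumably wire these classes into the PyTorch DeepONet implementations so that a strategy can be selected by name. A hypothetical user-facing call could look like the sketch below, where the constructor arguments and the keyword names num_outputs and multi_output_strategy are assumptions based on this diff rather than code shown in it.

import deepxde as dde

# Hypothetical usage sketch; keyword names are assumed, not confirmed by this diff.
net = dde.nn.DeepONetCartesianProd(
    [100, 128, 128],  # branch layer sizes
    [2, 128, 128],    # trunk layer sizes
    "relu",
    "Glorot normal",
    num_outputs=2,
    multi_output_strategy="split_both",
)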