|
| 1 | +from abc import ABC, abstractmethod |
| 2 | + |
| 3 | + |
class DeepONetStrategy(ABC):
    """Abstract interface for assembling and evaluating a DeepONet.

    Concrete strategies decide how the branch and trunk sub-networks are
    constructed and how their outputs are combined into one or more output
    functions.

    See the section 3.1.6. in
    L. Lu, X. Meng, S. Cai, Z. Mao, S. Goswami, Z. Zhang, & G. Karniadakis.
    A comprehensive and fair comparison of two neural operators
    (with practical extensions) based on FAIR data.
    Computer Methods in Applied Mechanics and Engineering, 393, 114778, 2022.
    """

    def __init__(self, net):
        # The DeepONet instance this strategy builds and evaluates for.
        self.net = net

    @abstractmethod
    def build(self, layer_sizes_branch, layer_sizes_trunk):
        """Construct and return the branch and trunk nets."""

    @abstractmethod
    def call(self, x_func, x_loc):
        """Evaluate the network on branch input x_func and trunk input x_loc."""
| 25 | + |
class SingleOutputStrategy(DeepONetStrategy):
    """Standard DeepONet: one output function from a single branch-trunk merge."""

    def build(self, layer_sizes_branch, layer_sizes_trunk):
        # The merge (dot product) requires equal feature widths on both sides.
        if layer_sizes_branch[-1] != layer_sizes_trunk[-1]:
            raise AssertionError(
                "Output sizes of branch net and trunk net do not match."
            )
        return (
            self.net.build_branch_net(layer_sizes_branch),
            self.net.build_trunk_net(layer_sizes_trunk),
        )

    def call(self, x_func, x_loc):
        branch_out = self.net.branch(x_func)
        trunk_out = self.net.activation_trunk(self.net.trunk(x_loc))
        # Guard against a runtime width mismatch before merging.
        if branch_out.shape[-1] != trunk_out.shape[-1]:
            raise AssertionError(
                "Output sizes of branch net and trunk net do not match."
            )
        return self.net.merge_branch_trunk(branch_out, trunk_out, 0)
| 47 | + |
| 48 | + |
class IndependentStrategy(DeepONetStrategy):
    """Use n fully independent DeepONets; the i-th one produces the i-th output."""

    def build(self, layer_sizes_branch, layer_sizes_trunk):
        single = SingleOutputStrategy(self.net)
        # One (branch, trunk) pair per output, all with identical layer sizes.
        pairs = [
            single.build(layer_sizes_branch, layer_sizes_trunk)
            for _ in range(self.net.num_outputs)
        ]
        branch = [b for b, _ in pairs]
        trunk = [t for _, t in pairs]
        return branch, trunk

    def call(self, x_func, x_loc):
        outputs = []
        for i in range(self.net.num_outputs):
            branch_out = self.net.branch[i](x_func)
            trunk_out = self.net.activation_trunk(self.net.trunk[i](x_loc))
            outputs.append(self.net.merge_branch_trunk(branch_out, trunk_out, i))
        return self.net.concatenate_outputs(outputs)
| 73 | + |
| 74 | + |
class SplitBothStrategy(DeepONetStrategy):
    """Split the outputs of both the branch and trunk nets into n equal groups.

    The k-th group of branch neurons is merged with the k-th group of trunk
    neurons to produce the k-th output function. For example, with n = 2 and
    100 output neurons on each side, neurons 0-49 generate the first function
    and neurons 50-99 generate the second.
    """

    def build(self, layer_sizes_branch, layer_sizes_trunk):
        width = layer_sizes_branch[-1]
        if width != layer_sizes_trunk[-1]:
            raise AssertionError(
                "Output sizes of branch net and trunk net do not match."
            )
        if width % self.net.num_outputs != 0:
            raise AssertionError(
                f"Output size of the branch net is not evenly divisible by {self.net.num_outputs}."
            )
        # The sub-networks themselves are built exactly as in the single-output case.
        return SingleOutputStrategy(self.net).build(
            layer_sizes_branch, layer_sizes_trunk
        )

    def call(self, x_func, x_loc):
        x_func = self.net.branch(x_func)
        x_loc = self.net.activation_trunk(self.net.trunk(x_loc))
        # Carve both feature dimensions into num_outputs equal slices.
        size = x_func.shape[1] // self.net.num_outputs
        xs = []
        for i in range(self.net.num_outputs):
            lo = i * size
            xs.append(
                self.net.merge_branch_trunk(
                    x_func[:, lo : lo + size], x_loc[:, lo : lo + size], i
                )
            )
        return self.net.concatenate_outputs(xs)
| 111 | + |
| 112 | + |
class SplitBranchStrategy(DeepONetStrategy):
    """Split the branch net's output into n groups; the trunk net is shared."""

    def build(self, layer_sizes_branch, layer_sizes_trunk):
        if layer_sizes_branch[-1] % self.net.num_outputs != 0:
            raise AssertionError(
                f"Output size of the branch net is not evenly divisible by {self.net.num_outputs}."
            )
        # Each branch slice must be exactly as wide as the full trunk output.
        if layer_sizes_branch[-1] / self.net.num_outputs != layer_sizes_trunk[-1]:
            raise AssertionError(
                f"Output size of the trunk net does not equal to {layer_sizes_branch[-1] // self.net.num_outputs}."
            )
        branch = self.net.build_branch_net(layer_sizes_branch)
        trunk = self.net.build_trunk_net(layer_sizes_trunk)
        return branch, trunk

    def call(self, x_func, x_loc):
        x_func = self.net.branch(x_func)
        x_loc = self.net.activation_trunk(self.net.trunk(x_loc))
        # The shared trunk width fixes the size of each branch slice.
        size = x_loc.shape[1]
        xs = [
            self.net.merge_branch_trunk(x_func[:, i * size : (i + 1) * size], x_loc, i)
            for i in range(self.net.num_outputs)
        ]
        return self.net.concatenate_outputs(xs)
| 142 | + |
| 143 | + |
class SplitTrunkStrategy(DeepONetStrategy):
    """Split the trunk net's output into n groups; the branch net is shared."""

    def build(self, layer_sizes_branch, layer_sizes_trunk):
        if layer_sizes_trunk[-1] % self.net.num_outputs != 0:
            raise AssertionError(
                f"Output size of the trunk net is not evenly divisible by {self.net.num_outputs}."
            )
        # Each trunk slice must be exactly as wide as the full branch output.
        if layer_sizes_trunk[-1] / self.net.num_outputs != layer_sizes_branch[-1]:
            raise AssertionError(
                f"Output size of the branch net does not equal to {layer_sizes_trunk[-1] // self.net.num_outputs}."
            )
        branch = self.net.build_branch_net(layer_sizes_branch)
        trunk = self.net.build_trunk_net(layer_sizes_trunk)
        return branch, trunk

    def call(self, x_func, x_loc):
        x_func = self.net.branch(x_func)
        x_loc = self.net.activation_trunk(self.net.trunk(x_loc))
        # The shared branch width fixes the size of each trunk slice.
        size = x_func.shape[1]
        xs = [
            self.net.merge_branch_trunk(x_func, x_loc[:, i * size : (i + 1) * size], i)
            for i in range(self.net.num_outputs)
        ]
        return self.net.concatenate_outputs(xs)