44from .nn import NN
55from .. import activations
66from .. import initializers
7+ from ..deeponet_strategy import (
8+ SingleOutputStrategy ,
9+ IndependentStrategy ,
10+ SplitBothStrategy ,
11+ SplitBranchStrategy ,
12+ SplitTrunkStrategy ,
13+ )
714
815
916class DeepONet (NN ):
@@ -89,14 +96,40 @@ class DeepONetCartesianProd(NN):
8996 Args:
9097 layer_sizes_branch: A list of integers as the width of a fully connected network,
9198 or `(dim, f)` where `dim` is the input dimension and `f` is a network
92- function. The width of the last layer in the branch and trunk net should be
93- equal .
99+ function. The width of the last layer in the branch and trunk net
100+ should be the same for all strategies except "split_branch" and "split_trunk".
94101 layer_sizes_trunk (list): A list of integers as the width of a fully connected
95102 network.
96103 activation: If `activation` is a ``string``, then the same activation is used in
97104 both trunk and branch nets. If `activation` is a ``dict``, then the trunk
98105 net uses the activation `activation["trunk"]`, and the branch net uses
99106 `activation["branch"]`.
107+ num_outputs (integer): Number of outputs. In case of multiple outputs, i.e., `num_outputs` > 1,
108+ `multi_output_strategy` below should be set.
109+ multi_output_strategy (str or None): ``None``, "independent", "split_both", "split_branch" or
110+ "split_trunk". Should be set when there are multiple outputs, i.e., `num_outputs` > 1;
110+ if left as ``None`` in that case, "independent" is used and a warning is printed.
111+
112+ - None
113+ Classical implementation of DeepONet with a single output.
114+ Cannot be used with `num_outputs` > 1.
115+
116+ - independent
117+ Use `num_outputs` independent DeepONets, and each DeepONet outputs only
118+ one function.
119+
120+ - split_both
121+ Split the outputs of both the branch net and the trunk net into `num_outputs`
122+ groups, and then the kth group outputs the kth solution.
123+
124+ - split_branch
125+ Split the branch net and share the trunk net. The width of the last layer
126+ in the branch net should be equal to the one in the trunk net multiplied
127+ by the number of outputs.
128+
129+ - split_trunk
130+ Split the trunk net and share the branch net. The width of the last layer
131+ in the trunk net should be equal to the one in the branch net multiplied
132+ by the number of outputs.
100133 """
101134
102135 def __init__ (
@@ -105,45 +138,81 @@ def __init__(
105138 layer_sizes_trunk ,
106139 activation ,
107140 kernel_initializer ,
108- regularization = None ,
141+ num_outputs = 1 ,
142+ multi_output_strategy = None ,
109143 ):
110144 super ().__init__ ()
111145 if isinstance (activation , dict ):
112- activation_branch = activation ["branch" ]
146+ self . activation_branch = activation ["branch" ]
113147 self .activation_trunk = activations .get (activation ["trunk" ])
114148 else :
115- activation_branch = self .activation_trunk = activations .get (activation )
116- if callable (layer_sizes_branch [1 ]):
117- # User-defined network
118- self .branch = layer_sizes_branch [1 ]
119- else :
120- # Fully connected network
121- self .branch = FNN (layer_sizes_branch , activation_branch , kernel_initializer )
122- self .trunk = FNN (layer_sizes_trunk , self .activation_trunk , kernel_initializer )
123- # register bias to parameter for updating in optimizer and storage
124- self .b = self .create_parameter (
125- shape = (1 ,), default_initializer = initializers .get ("zeros" )
149+ self .activation_branch = self .activation_trunk = activations .get (activation )
150+ self .kernel_initializer = kernel_initializer
151+
152+ self .num_outputs = num_outputs
153+ if self .num_outputs == 1 :
154+ if multi_output_strategy is not None :
155+ raise ValueError (
156+ "num_outputs is set to 1, but multi_output_strategy is not None."
157+ )
158+ elif multi_output_strategy is None :
159+ multi_output_strategy = "independent"
160+ print (
161+ f"Warning: There are { num_outputs } outputs, but no multi_output_strategy selected. "
162+ 'Use "independent" as the multi_output_strategy.'
163+ )
164+ self .multi_output_strategy = {
165+ None : SingleOutputStrategy ,
166+ "independent" : IndependentStrategy ,
167+ "split_both" : SplitBothStrategy ,
168+ "split_branch" : SplitBranchStrategy ,
169+ "split_trunk" : SplitTrunkStrategy ,
170+ }[multi_output_strategy ](self )
171+
172+ self .branch , self .trunk = self .multi_output_strategy .build (
173+ layer_sizes_branch , layer_sizes_trunk
174+ )
175+ if isinstance (self .branch , list ):
176+ self .branch = paddle .nn .LayerList (self .branch )
177+ if isinstance (self .trunk , list ):
178+ self .trunk = paddle .nn .LayerList (self .trunk )
179+ self .b = paddle .nn .ParameterList (
180+ [
181+ paddle .create_parameter (
182+ shape = [1 ,],
183+ dtype = paddle .get_default_dtype (),
184+ default_initializer = paddle .nn .initializer .Constant (value = 0 ),
185+ )
186+ for _ in range (self .num_outputs )
187+ ]
126188 )
127- self .regularizer = regularization
189+
190+ def build_branch_net (self , layer_sizes_branch ):
191+ # User-defined network
192+ if callable (layer_sizes_branch [1 ]):
193+ return layer_sizes_branch [1 ]
194+ # Fully connected network
195+ return FNN (layer_sizes_branch , self .activation_branch , self .kernel_initializer )
196+
197+ def build_trunk_net (self , layer_sizes_trunk ):
198+ return FNN (layer_sizes_trunk , self .activation_trunk , self .kernel_initializer )
199+
200+ def merge_branch_trunk (self , x_func , x_loc , index ):
201+ y = x_func @ x_loc .T
202+ y += self .b [index ]
203+ return y
204+
205+ @staticmethod
206+ def concatenate_outputs (ys ):
207+ return paddle .stack (ys , axis = 2 )
128208
129209 def forward (self , inputs ):
130210 x_func = inputs [0 ]
131211 x_loc = inputs [1 ]
132- # Branch net to encode the input function
133- x_func = self .branch (x_func )
134- # Trunk net to encode the domain of the output function
212+ # Trunk net input transform
135213 if self ._input_transform is not None :
136214 x_loc = self ._input_transform (x_loc )
137- x_loc = self .activation_trunk (self .trunk (x_loc ))
138- # Dot product
139- if x_func .shape [- 1 ] != x_loc .shape [- 1 ]:
140- raise AssertionError (
141- "Output sizes of branch net and trunk net do not match."
142- )
143- x = x_func @ x_loc .T
144- # Add bias
145- x += self .b
146-
215+ x = self .multi_output_strategy .call (x_func , x_loc )
147216 if self ._output_transform is not None :
148217 x = self ._output_transform (inputs , x )
149218 return x
0 commit comments