@@ -235,6 +235,7 @@ def __init__(
235235 kernel_initializer ,
236236 num_outputs = 1 ,
237237 multi_output_strategy = None ,
238+ regularization = None ,
238239 ):
239240 super ().__init__ ()
240241 if isinstance (activation , dict ):
@@ -243,6 +244,7 @@ def __init__(
243244 else :
244245 self .activation_branch = self .activation_trunk = activations .get (activation )
245246 self .kernel_initializer = kernel_initializer
247+ self .regularization = regularization
246248
247249 self .num_outputs = num_outputs
248250 if self .num_outputs == 1 :
@@ -280,10 +282,20 @@ def build_branch_net(self, layer_sizes_branch):
280282 if callable (layer_sizes_branch [1 ]):
281283 return layer_sizes_branch [1 ]
282284 # Fully connected network
283- return FNN (layer_sizes_branch , self .activation_branch , self .kernel_initializer )
285+ return FNN (
286+ layer_sizes_branch ,
287+ self .activation_branch ,
288+ self .kernel_initializer ,
289+ regularization = self .regularization ,
290+ )
284291
285292 def build_trunk_net (self , layer_sizes_trunk ):
286- return FNN (layer_sizes_trunk , self .activation_trunk , self .kernel_initializer )
293+ return FNN (
294+ layer_sizes_trunk ,
295+ self .activation_trunk ,
296+ self .kernel_initializer ,
297+ regularization = self .regularization ,
298+ )
287299
288300 def merge_branch_trunk (self , x_func , x_loc ):
289301 y = tf .einsum ("bi,bi->b" , x_func , x_loc )
0 commit comments