@@ -68,6 +68,7 @@ def __init__(
6868 kernel_initializer ,
6969 num_outputs = 1 ,
7070 multi_output_strategy = None ,
71+ regularization = None ,
7172 ):
7273 super ().__init__ ()
7374 if isinstance (activation , dict ):
@@ -76,6 +77,7 @@ def __init__(
7677 else :
7778 self .activation_branch = self .activation_trunk = activations .get (activation )
7879 self .kernel_initializer = kernel_initializer
80+ self .regularizer = regularization
7981
8082 self .num_outputs = num_outputs
8183 if self .num_outputs == 1 :
@@ -190,6 +192,7 @@ def __init__(
190192 kernel_initializer ,
191193 num_outputs = 1 ,
192194 multi_output_strategy = None ,
195+ regularization = None ,
193196 ):
194197 super ().__init__ ()
195198 if isinstance (activation , dict ):
@@ -198,6 +201,7 @@ def __init__(
198201 else :
199202 self .activation_branch = self .activation_trunk = activations .get (activation )
200203 self .kernel_initializer = kernel_initializer
204+ self .regularizer = regularization
201205
202206 self .num_outputs = num_outputs
203207 if self .num_outputs == 1 :
@@ -295,7 +299,7 @@ def __init__(
295299 regularization = None ,
296300 ):
297301 super ().__init__ ()
298- self .regularization = regularization # TODO: currently unused
302+ self .regularizer = regularization
299303 self .pod_basis = torch .as_tensor (pod_basis , dtype = torch .float32 )
300304 if isinstance (activation , dict ):
301305 activation_branch = activation ["branch" ]
0 commit comments