@@ -321,7 +321,7 @@ def build_branch_net(self):
         if callable(self.layer_size_func[1]):
             # User-defined network
             return self.layer_size_func[1](self.X_func)
-
+
         if self.stacked:
             # Stacked fully connected network
             return self._build_stacked_branch_net()
@@ -422,15 +422,18 @@ def _dense(
         regularizer=None,
         trainable=True,
     ):
-        return tf.layers.dense(
-            inputs,
+        dense = tf.keras.layers.Dense(
             units,
             activation=activation,
             use_bias=use_bias,
             kernel_initializer=self.kernel_initializer,
             kernel_regularizer=regularizer,
             trainable=trainable,
         )
+        out = dense(inputs)
+        if regularizer:
+            self.regularization_loss += tf.math.add_n(dense.losses)
+        return out
 
     def _stacked_dense(
         self, inputs, units, stack_size, activation=None, use_bias=True, trainable=True
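The hunk above swaps the deprecated `tf.layers.dense` call for an explicit `tf.keras.layers.Dense` layer object. The extra bookkeeping is needed because a Keras layer records its kernel-regularization penalty in its own `losses` list rather than in a global collection, so the caller must sum and accumulate it by hand. A minimal standalone sketch of that behavior (illustrative, not part of the patch):

    import tensorflow as tf

    # Building a regularized Dense layer creates one penalty tensor per
    # regularized weight; they end up in `layer.losses`, not in any
    # global collection, so the caller collects them explicitly.
    layer = tf.keras.layers.Dense(
        4, kernel_regularizer=tf.keras.regularizers.l2(1e-3)
    )
    _ = layer(tf.ones([2, 3]))              # first call builds the kernel
    reg_loss = tf.math.add_n(layer.losses)  # scalar L2 penalty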
@@ -637,24 +640,22 @@ def build_branch_net(self):
         else:
             # Fully connected network
             for i in range(1, len(self.layer_size_func) - 1):
-                y_func = tf.layers.dense(
+                y_func = self._dense(
                     y_func,
                     self.layer_size_func[i],
                     activation=self.activation_branch,
-                    kernel_initializer=self.kernel_initializer,
-                    kernel_regularizer=self.regularizer,
+                    regularizer=self.regularizer,
                 )
                 if self.dropout_rate_branch[i - 1] > 0:
                     y_func = tf.layers.dropout(
                         y_func,
                         rate=self.dropout_rate_branch[i - 1],
                         training=self.training,
                     )
-            y_func = tf.layers.dense(
+            y_func = self._dense(
                 y_func,
                 self.layer_size_func[-1],
-                kernel_initializer=self.kernel_initializer,
-                kernel_regularizer=self.regularizer,
+                regularizer=self.regularizer,
             )
         return y_func
 
@@ -664,12 +665,11 @@ def build_trunk_net(self):
         if self._input_transform is not None:
             y_loc = self._input_transform(y_loc)
         for i in range(1, len(self.layer_size_loc)):
-            y_loc = tf.layers.dense(
+            y_loc = self._dense(
                 y_loc,
                 self.layer_size_loc[i],
                 activation=self.activation_trunk,
-                kernel_initializer=self.kernel_initializer,
-                kernel_regularizer=self.regularizer,
+                regularizer=self.regularizer,
             )
             if self.dropout_rate_trunk[i - 1] > 0:
                 y_loc = tf.layers.dropout(
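The unchanged dropout lines in both nets use the compat-v1 `tf.layers.dropout`, whose `training` argument accepts a boolean tensor, so a single graph serves both phases. A minimal TF1-style sketch (assuming graph mode; not part of the diff):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    training = tf.placeholder(tf.bool)  # fed at session-run time
    x = tf.ones([2, 8])
    y = tf.layers.dropout(x, rate=0.5, training=training)
    with tf.Session() as sess:
        y_train = sess.run(y, {training: True})   # units randomly zeroed
        y_test = sess.run(y, {training: False})   # identity at test time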
@@ -687,3 +687,25 @@ def merge_branch_trunk(self, branch, trunk):
     @staticmethod
     def concatenate_outputs(ys):
         return tf.stack(ys, axis=2)
+
+    def _dense(
+        self,
+        inputs,
+        units,
+        activation=None,
+        use_bias=True,
+        regularizer=None,
+        trainable=True,
+    ):
+        dense = tf.keras.layers.Dense(
+            units,
+            activation=activation,
+            use_bias=use_bias,
+            kernel_initializer=self.kernel_initializer,
+            kernel_regularizer=regularizer,
+            trainable=trainable,
+        )
+        out = dense(inputs)
+        if regularizer:
+            self.regularization_loss += tf.math.add_n(dense.losses)
+        return out
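This last hunk adds the same `_dense` helper to a second network class in the module, so both variants accumulate penalties identically. Note the `if regularizer:` guard: without a regularizer the layer's `losses` list is empty, and `tf.math.add_n` raises on an empty list. A quick standalone illustration:

    import tensorflow as tf

    plain = tf.keras.layers.Dense(4)  # no kernel_regularizer
    _ = plain(tf.ones([1, 3]))
    print(plain.losses)  # [] -- nothing to sum
    # tf.math.add_n([]) raises ValueError, hence the guard above.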