@@ -422,14 +422,18 @@ def _dense(
         regularizer=None,
         trainable=True,
     ):
-        return tf.keras.layers.Dense(
+        dense = tf.keras.layers.Dense(
             units,
             activation=activation,
             use_bias=use_bias,
             kernel_initializer=self.kernel_initializer,
             kernel_regularizer=regularizer,
             trainable=trainable,
-        )(inputs)
+        )
+        out = dense(inputs)
+        if regularizer:
+            self.regularization_loss += tf.math.add_n(dense.losses)
+        return out
 
     def _stacked_dense(
         self, inputs, units, stack_size, activation=None, use_bias=True, trainable=True
@@ -636,23 +640,23 @@ def build_branch_net(self):
         else:
             # Fully connected network
             for i in range(1, len(self.layer_size_func) - 1):
-                y_func = tf.keras.layers.Dense(
+                y_func = self._dense(
+                    y_func,
                     self.layer_size_func[i],
                     activation=self.activation_branch,
-                    kernel_initializer=self.kernel_initializer,
-                    kernel_regularizer=self.regularizer,
-                )(y_func)
+                    regularizer=self.regularizer,
+                )
                 if self.dropout_rate_branch[i - 1] > 0:
                     y_func = tf.layers.dropout(
                         y_func,
                         rate=self.dropout_rate_branch[i - 1],
                         training=self.training,
                     )
-            y_func = tf.keras.layers.Dense(
+            y_func = self._dense(
+                y_func,
                 self.layer_size_func[-1],
-                kernel_initializer=self.kernel_initializer,
-                kernel_regularizer=self.regularizer,
-            )(y_func)
+                regularizer=self.regularizer,
+            )
         return y_func
 
     def build_trunk_net(self):
@@ -661,12 +665,12 @@ def build_trunk_net(self):
         if self._input_transform is not None:
             y_loc = self._input_transform(y_loc)
         for i in range(1, len(self.layer_size_loc)):
-            y_loc = tf.keras.layers.Dense(
+            y_loc = self._dense(
+                y_loc,
                 self.layer_size_loc[i],
                 activation=self.activation_trunk,
-                kernel_initializer=self.kernel_initializer,
-                kernel_regularizer=self.regularizer,
-            )(y_loc)
+                regularizer=self.regularizer,
+            )
             if self.dropout_rate_trunk[i - 1] > 0:
                 y_loc = tf.layers.dropout(
                     y_loc, rate=self.dropout_rate_trunk[i - 1], training=self.training
@@ -683,3 +687,25 @@ def merge_branch_trunk(self, branch, trunk):
     @staticmethod
     def concatenate_outputs(ys):
         return tf.stack(ys, axis=2)
+
+    def _dense(
+        self,
+        inputs,
+        units,
+        activation=None,
+        use_bias=True,
+        regularizer=None,
+        trainable=True,
+    ):
+        dense = tf.keras.layers.Dense(
+            units,
+            activation=activation,
+            use_bias=use_bias,
+            kernel_initializer=self.kernel_initializer,
+            kernel_regularizer=regularizer,
+            trainable=trainable,
+        )
+        out = dense(inputs)
+        if regularizer:
+            self.regularization_loss += tf.math.add_n(dense.losses)
+        return out
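
# --- Illustration (not part of the commit) ---------------------------------
# A minimal, self-contained sketch of the pattern the diff introduces: _dense
# applies a Keras Dense layer and folds the layer's kernel-regularization
# penalties (dense.losses) into a running regularization_loss, which the
# training code can then add to the data loss. Assumes TF 2.x eager execution;
# TinyNet and the loss assembly at the end are illustrative, not the project's
# actual API.
import tensorflow as tf


class TinyNet:
    def __init__(self, regularizer=None):
        self.kernel_initializer = "glorot_normal"
        self.regularizer = regularizer
        # Running sum of regularization penalties, as in the diff above.
        self.regularization_loss = tf.constant(0.0)

    def _dense(self, inputs, units, activation=None, regularizer=None):
        dense = tf.keras.layers.Dense(
            units,
            activation=activation,
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=regularizer,
        )
        out = dense(inputs)
        if regularizer:
            # After the layer is called, dense.losses holds its penalty terms.
            self.regularization_loss += tf.math.add_n(dense.losses)
        return out

    def __call__(self, x):
        h = self._dense(x, 32, activation=tf.nn.relu, regularizer=self.regularizer)
        return self._dense(h, 1, regularizer=self.regularizer)


net = TinyNet(regularizer=tf.keras.regularizers.l2(1e-4))
y = net(tf.random.normal([8, 4]))
# Total objective = data-fitting loss + accumulated regularization penalty.
total_loss = tf.reduce_mean(tf.square(y)) + net.regularization_loss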