@@ -321,54 +321,48 @@ def build_branch_net(self):
         y_func = self.X_func
         if callable(self.layer_size_func[1]):
             # User-defined network
-            y_func = self.layer_size_func[1](y_func)
-        elif self.stacked:
-            # Stacked fully connected network
-            stack_size = self.layer_size_func[-1]
-            for i in range(1, len(self.layer_size_func) - 1):
-                y_func = self._stacked_dense(
-                    y_func,
-                    self.layer_size_func[i],
-                    stack_size,
-                    activation=self.activation_branch,
+            return self.layer_size_func[1](y_func)
+
+        def _add_branch_layer(
+            inputs, units, stack_size=None, activation=None, use_bias=True
+        ):
+            if stack_size is None:
+                return self._dense(
+                    inputs,
+                    units,
+                    activation=activation,
+                    regularizer=self.regularizer,
                     trainable=self.trainable_branch,
+                    use_bias=use_bias,
                 )
-                if self.dropout_rate_branch[i - 1] > 0:
-                    y_func = tf.layers.dropout(
-                        y_func,
-                        rate=self.dropout_rate_branch[i - 1],
-                        training=self.training,
-                    )
-            y_func = self._stacked_dense(
-                y_func,
-                1,
+            return self._stacked_dense(
+                inputs,
+                units,
                 stack_size,
-                use_bias=self.use_bias,
+                activation=activation,
                 trainable=self.trainable_branch,
+                use_bias=use_bias,
             )
-        else:
-            # Unstacked fully connected network
-            for i in range(1, len(self.layer_size_func) - 1):
-                y_func = self._dense(
-                    y_func,
-                    self.layer_size_func[i],
-                    activation=self.activation_branch,
-                    regularizer=self.regularizer,
-                    trainable=self.trainable_branch,
-                )
-                if self.dropout_rate_branch[i - 1] > 0:
-                    y_func = tf.layers.dropout(
-                        y_func,
-                        rate=self.dropout_rate_branch[i - 1],
-                        training=self.training,
-                    )
-            y_func = self._dense(
+
+        for i in range(1, len(self.layer_size_func) - 1):
+            y_func = _add_branch_layer(
                 y_func,
-                self.layer_size_func[-1],
-                use_bias=self.use_bias,
-                regularizer=self.regularizer,
-                trainable=self.trainable_branch,
+                self.layer_size_func[i],
+                self.layer_size_func[-1] if self.stacked else None,
+                activation=self.activation_branch,
             )
+            if self.dropout_rate_branch[i - 1] > 0:
+                y_func = tf.layers.dropout(
+                    y_func,
+                    rate=self.dropout_rate_branch[i - 1],
+                    training=self.training,
+                )
+        y_func = _add_branch_layer(
+            y_func,
+            1 if self.stacked else self.layer_size_func[-1],
+            self.layer_size_func[-1] if self.stacked else None,
+            use_bias=self.use_bias,
+        )
         return y_func
 
     def build_trunk_net(self):
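For context, the change above removes the duplicated stacked/unstacked branches: a single nested helper now dispatches to the stacked or unstacked dense layer depending on whether a `stack_size` is given, so one loop builds the whole branch net. Below is a minimal standalone sketch of that dispatch pattern; the `dense`, `stacked_dense`, `add_layer`, and `build_branch` names are illustrative stand-ins, not DeepXDE's actual API.

```python
# Illustrative sketch only: placeholder layer constructors, not DeepXDE's API.

def dense(x, units, activation=None, use_bias=True):
    # Stand-in for an unstacked fully connected layer.
    return f"dense({x}, units={units}, act={activation}, bias={use_bias})"

def stacked_dense(x, units, stack_size, activation=None, use_bias=True):
    # Stand-in for a stacked fully connected layer (one sub-network per output).
    return f"stacked_dense({x}, units={units}, stack={stack_size}, act={activation}, bias={use_bias})"

def add_layer(x, units, stack_size=None, activation=None, use_bias=True):
    # Single entry point: fall back to the unstacked layer when no stack size is given.
    if stack_size is None:
        return dense(x, units, activation=activation, use_bias=use_bias)
    return stacked_dense(x, units, stack_size, activation=activation, use_bias=use_bias)

def build_branch(layer_sizes, stacked):
    # One loop covers both the stacked and unstacked cases.
    y = "x"
    stack_size = layer_sizes[-1] if stacked else None
    for units in layer_sizes[1:-1]:
        y = add_layer(y, units, stack_size, activation="relu")
    # Output layer: width 1 per stack when stacked, full width otherwise.
    return add_layer(y, 1 if stacked else layer_sizes[-1], stack_size)

print(build_branch([10, 64, 64, 32], stacked=False))
print(build_branch([10, 64, 64, 32], stacked=True))
```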