@@ -318,58 +318,67 @@ def build(self):
318318 self .built = True
319319
320320 def build_branch_net (self ):
321- y_func = self .X_func
322321 if callable (self .layer_size_func [1 ]):
323322 # User-defined network
324- y_func = self .layer_size_func [1 ](y_func )
325- elif self .stacked :
323+ return self .layer_size_func [1 ](self .X_func )
324+
325+ if self .stacked :
326326 # Stacked fully connected network
327- stack_size = self .layer_size_func [- 1 ]
328- for i in range (1 , len (self .layer_size_func ) - 1 ):
329- y_func = self ._stacked_dense (
330- y_func ,
331- self .layer_size_func [i ],
332- stack_size ,
333- activation = self .activation_branch ,
334- trainable = self .trainable_branch ,
335- )
336- if self .dropout_rate_branch [i - 1 ] > 0 :
337- y_func = tf .layers .dropout (
338- y_func ,
339- rate = self .dropout_rate_branch [i - 1 ],
340- training = self .training ,
341- )
327+ return self ._build_stacked_branch_net ()
328+
329+ # Unstacked fully connected network
330+ return self ._build_unstacked_branch_net ()
331+
332+ def _build_stacked_branch_net (self ):
333+ y_func = self .X_func
334+ stack_size = self .layer_size_func [- 1 ]
335+
336+ for i in range (1 , len (self .layer_size_func ) - 1 ):
342337 y_func = self ._stacked_dense (
343338 y_func ,
344- 1 ,
345- stack_size ,
346- use_bias = self .use_bias ,
339+ self . layer_size_func [ i ] ,
340+ stack_size = stack_size ,
341+ activation = self .activation_branch ,
347342 trainable = self .trainable_branch ,
348343 )
349- else :
350- # Unstacked fully connected network
351- for i in range (1 , len (self .layer_size_func ) - 1 ):
352- y_func = self ._dense (
344+ if self .dropout_rate_branch [i - 1 ] > 0 :
345+ y_func = tf .layers .dropout (
353346 y_func ,
354- self .layer_size_func [i ],
355- activation = self .activation_branch ,
356- regularizer = self .regularizer ,
357- trainable = self .trainable_branch ,
347+ rate = self .dropout_rate_branch [i - 1 ],
348+ training = self .training ,
358349 )
359- if self .dropout_rate_branch [i - 1 ] > 0 :
360- y_func = tf .layers .dropout (
361- y_func ,
362- rate = self .dropout_rate_branch [i - 1 ],
363- training = self .training ,
364- )
350+ return self ._stacked_dense (
351+ y_func ,
352+ 1 ,
353+ stack_size = stack_size ,
354+ use_bias = self .use_bias ,
355+ trainable = self .trainable_branch ,
356+ )
357+
358+ def _build_unstacked_branch_net (self ):
359+ y_func = self .X_func
360+
361+ for i in range (1 , len (self .layer_size_func ) - 1 ):
365362 y_func = self ._dense (
366363 y_func ,
367- self .layer_size_func [- 1 ],
368- use_bias = self .use_bias ,
364+ self .layer_size_func [i ],
365+ activation = self .activation_branch ,
369366 regularizer = self .regularizer ,
370367 trainable = self .trainable_branch ,
371368 )
372- return y_func
369+ if self .dropout_rate_branch [i - 1 ] > 0 :
370+ y_func = tf .layers .dropout (
371+ y_func ,
372+ rate = self .dropout_rate_branch [i - 1 ],
373+ training = self .training ,
374+ )
375+ return self ._dense (
376+ y_func ,
377+ self .layer_size_func [- 1 ],
378+ use_bias = self .use_bias ,
379+ regularizer = self .regularizer ,
380+ trainable = self .trainable_branch ,
381+ )
373382
374383 def build_trunk_net (self ):
375384 y_loc = self .X_loc
0 commit comments