@@ -58,13 +58,21 @@ def build(self):
5858 self ._inputs = [self .X_func1 , self .X_func2 , self .X_loc ]
5959
6060 # Branch net 1
61- y_func1 = self ._net (
62- self .X_func1 , self .layer_branch1 [1 :], self .activation_branch1
63- )
61+ if callable (self .layer_branch1 [1 ]):
62+ # User-defined network
63+ y_func1 = self .layer_branch1 [1 ](self .X_func1 )
64+ else :
65+ y_func1 = self ._net (
66+ self .X_func1 , self .layer_branch1 [1 :], self .activation_branch1
67+ )
6468 # Branch net 2
65- y_func2 = self ._net (
66- self .X_func2 , self .layer_branch2 [1 :], self .activation_branch2
67- )
69+ if callable (self .layer_branch2 [1 ]):
70+ # User-defined network
71+ y_func2 = self .layer_branch2 [1 ](self .X_func2 )
72+ else :
73+ y_func2 = self ._net (
74+ self .X_func2 , self .layer_branch2 [1 :], self .activation_branch2
75+ )
6876 # Trunk net
6977 y_loc = self ._net (self .X_loc , self .layer_trunk [1 :], self .activation_trunk )
7078
@@ -103,13 +111,21 @@ def build(self):
103111 self ._inputs = [self .X_func1 , self .X_func2 , self .X_loc ]
104112
105113 # Branch net 1
106- y_func1 = self ._net (
107- self .X_func1 , self .layer_branch1 [1 :], self .activation_branch1
108- )
114+ if callable (self .layer_branch1 [1 ]):
115+ # User-defined network
116+ y_func1 = self .layer_branch1 [1 ](self .X_func1 )
117+ else :
118+ y_func1 = self ._net (
119+ self .X_func1 , self .layer_branch1 [1 :], self .activation_branch1
120+ )
109121 # Branch net 2
110- y_func2 = self ._net (
111- self .X_func2 , self .layer_branch2 [1 :], self .activation_branch2
112- )
122+ if callable (self .layer_branch2 [1 ]):
123+ # User-defined network
124+ y_func2 = self .layer_branch2 [1 ](self .X_func2 )
125+ else :
126+ y_func2 = self ._net (
127+ self .X_func2 , self .layer_branch2 [1 :], self .activation_branch2
128+ )
113129 # Trunk net
114130 y_loc = self ._net (self .X_loc , self .layer_trunk [1 :], self .activation_trunk )
115131
0 commit comments