@@ -70,7 +70,9 @@ class PFNN(NN):
7070 kernel_initializer: Initializer for the kernel weights.
7171 """
7272
73- def __init__ (self , layer_sizes , activation , kernel_initializer , regularization = None ):
73+ def __init__ (
74+ self , layer_sizes , activation , kernel_initializer , regularization = None
75+ ):
7476 super ().__init__ ()
7577 self .activation = activations .get (activation )
7678 initializer = initializers .get (kernel_initializer )
@@ -83,13 +85,19 @@ def __init__(self, layer_sizes, activation, kernel_initializer, regularization=N
8385 raise ValueError ("input size must be integer" )
8486
8587 # Determine the number of subnetworks from the first list layer
86- list_layers = [layer for layer in layer_sizes if isinstance (layer , (list , tuple ))]
88+ list_layers = [
89+ layer for layer in layer_sizes if isinstance (layer , (list , tuple ))
90+ ]
8791 if not list_layers :
88- raise ValueError ("No list layers found; use FNN instead of PFNN for single subnetwork." )
92+ raise ValueError (
93+ "No list layers found; use FNN instead of PFNN for single subnetwork."
94+ )
8995 n_subnetworks = len (list_layers [0 ])
9096 for layer in list_layers :
9197 if len (layer ) != n_subnetworks :
92- raise ValueError ("All list layers must have the same length as the first list layer." )
98+ raise ValueError (
99+ "All list layers must have the same length as the first list layer."
100+ )
93101
94102 # Validate output layer if preceded by a list layer
95103 if (
@@ -118,10 +126,16 @@ def make_linear(n_input, n_output):
118126 # Parallel layer
119127 if isinstance (prev_layer , (list , tuple )):
120128 # Previous is parallel: each subnetwork input is previous subnetwork output
121- sub_layers = [make_linear (prev_layer [j ], curr_layer [j ]) for j in range (n_subnetworks )]
129+ sub_layers = [
130+ make_linear (prev_layer [j ], curr_layer [j ])
131+ for j in range (n_subnetworks )
132+ ]
122133 else :
123134 # Previous is shared: all subnetworks take the same input
124- sub_layers = [make_linear (prev_layer , curr_layer [j ]) for j in range (n_subnetworks )]
135+ sub_layers = [
136+ make_linear (prev_layer , curr_layer [j ])
137+ for j in range (n_subnetworks )
138+ ]
125139 self .layers .append (torch .nn .ModuleList (sub_layers ))
126140 else :
127141 # Shared layer
@@ -139,15 +153,23 @@ def make_linear(n_input, n_output):
139153 if isinstance (output_layer , (list , tuple )):
140154 if isinstance (prev_output_layer , (list , tuple )):
141155 # Each subnetwork input is corresponding previous output
142- output_layers = [make_linear (prev_output_layer [j ], output_layer [j ]) for j in range (n_subnetworks )]
156+ output_layers = [
157+ make_linear (prev_output_layer [j ], output_layer [j ])
158+ for j in range (n_subnetworks )
159+ ]
143160 else :
144161 # All subnetworks take the same shared input
145- output_layers = [make_linear (prev_output_layer , output_layer [j ]) for j in range (n_subnetworks )]
162+ output_layers = [
163+ make_linear (prev_output_layer , output_layer [j ])
164+ for j in range (n_subnetworks )
165+ ]
146166 self .layers .append (torch .nn .ModuleList (output_layers ))
147167 else :
148168 if isinstance (prev_output_layer , (list , tuple )):
149169 # Each subnetwork outputs 1 and concatenates to output_layer size
150- output_layers = [make_linear (prev_output_layer [j ], 1 ) for j in range (n_subnetworks )]
170+ output_layers = [
171+ make_linear (prev_output_layer [j ], 1 ) for j in range (n_subnetworks )
172+ ]
151173 self .layers .append (torch .nn .ModuleList (output_layers ))
152174 else :
153175 # Shared output layer
0 commit comments