@@ -60,24 +60,46 @@ class PFNN(NN):
             (how the layers are connected). If `layer_sizes[i]` is an int, it represents
             one layer shared by all the outputs; if `layer_sizes[i]` is a list, it
             represents `len(layer_sizes[i])` sub-layers, each of which is exclusively
-            used by one output. Note that `len(layer_sizes[i])` should equal the number
-            of outputs. Every number specifies the number of neurons in that layer.
+            used by one output. Every list in `layer_sizes` must have the same length
+            (= number of subnetworks). If the last element of `layer_sizes` is an int
+            preceded by a list, it must be equal to the number of subnetworks: all
+            subnetworks have an output size of 1 and are then concatenated. If the last
+            element is a list, it specifies the output size for each subnetwork before
+            concatenation.
+        activation: Activation function.
+        kernel_initializer: Initializer for the kernel weights.
     """

-    def __init__(self, layer_sizes, activation, kernel_initializer):
+    def __init__(self, layer_sizes, activation, kernel_initializer, regularization=None):
         super().__init__()
         self.activation = activations.get(activation)
         initializer = initializers.get(kernel_initializer)
         initializer_zero = initializers.get("zeros")
+        self.regularizer = regularization

         if len(layer_sizes) <= 1:
             raise ValueError("must specify input and output sizes")
         if not isinstance(layer_sizes[0], int):
             raise ValueError("input size must be integer")
-        if not isinstance(layer_sizes[-1], int):
-            raise ValueError("output size must be integer")

-        n_output = layer_sizes[-1]
+        # Determine the number of subnetworks from the first list layer
+        list_layers = [layer for layer in layer_sizes if isinstance(layer, (list, tuple))]
+        if not list_layers:
+            raise ValueError("No list layers found; use FNN instead of PFNN for single subnetwork.")
+        n_subnetworks = len(list_layers[0])
+        for layer in list_layers:
+            if len(layer) != n_subnetworks:
+                raise ValueError("All list layers must have the same length as the first list layer.")
+
+        # Validate output layer if preceded by a list layer
+        if (
+            isinstance(layer_sizes[-1], int)
+            and isinstance(layer_sizes[-2], (list, tuple))
+            and layer_sizes[-1] != n_subnetworks
+        ):
+            raise ValueError(
+                "If last layer is an int and previous is a list, the int must equal the number of subnetworks."
+            )

         def make_linear(n_input, n_output):
             linear = torch.nn.Linear(n_input, n_output, dtype=config.real(torch))
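
For orientation, here is a minimal usage sketch of the `layer_sizes` convention the new docstring describes. It assumes DeepXDE's usual constructor call (`dde.nn.PFNN(layer_sizes, activation, kernel_initializer)`) on the PyTorch backend and a build that includes this change; the sizes are illustrative only.

```python
import deepxde as dde

# Two subnetworks: every list in layer_sizes has length 2.
# The last element is an int equal to the number of subnetworks, so each
# subnetwork ends in a single neuron and the two outputs are concatenated.
net = dde.nn.PFNN([3, [32, 32], [32, 32], 2], "tanh", "Glorot normal")

# The last element as a list: each subnetwork keeps its own output width
# (here 1 and 3) before concatenation, giving 4 outputs in total.
net_wide = dde.nn.PFNN([3, 64, [32, 32], [1, 3]], "tanh", "Glorot normal")
```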
@@ -86,49 +108,50 @@ def make_linear(n_input, n_output):
             return linear

         self.layers = torch.nn.ModuleList()
+
+        # Process hidden layers (excluding the output layer)
         for i in range(1, len(layer_sizes) - 1):
-            prev_layer_size = layer_sizes[i - 1]
-            curr_layer_size = layer_sizes[i]
-            if isinstance(curr_layer_size, (list, tuple)):
-                if len(curr_layer_size) != n_output:
-                    raise ValueError(
-                        "number of sub-layers should equal number of network outputs"
-                    )
-                if isinstance(prev_layer_size, (list, tuple)):
-                    # e.g. [8, 8, 8] -> [16, 16, 16]
-                    self.layers.append(
-                        torch.nn.ModuleList(
-                            [
-                                make_linear(prev_layer_size[j], curr_layer_size[j])
-                                for j in range(n_output)
-                            ]
-                        )
-                    )
-                else:  # e.g. 64 -> [8, 8, 8]
-                    self.layers.append(
-                        torch.nn.ModuleList(
-                            [
-                                make_linear(prev_layer_size, curr_layer_size[j])
-                                for j in range(n_output)
-                            ]
-                        )
-                    )
-            else:  # e.g. 64 -> 64
-                if not isinstance(prev_layer_size, int):
-                    raise ValueError(
-                        "cannot rejoin parallel subnetworks after splitting"
-                    )
-                self.layers.append(make_linear(prev_layer_size, curr_layer_size))
-
-        # output layers
-        if isinstance(layer_sizes[-2], (list, tuple)):  # e.g. [3, 3, 3] -> 3
-            self.layers.append(
-                torch.nn.ModuleList(
-                    [make_linear(layer_sizes[-2][j], 1) for j in range(n_output)]
-                )
-            )
+            prev_layer = layer_sizes[i - 1]
+            curr_layer = layer_sizes[i]
+
+            if isinstance(curr_layer, (list, tuple)):
+                # Parallel layer
+                if isinstance(prev_layer, (list, tuple)):
+                    # Previous is parallel: each subnetwork takes the output of its own previous sub-layer
+                    sub_layers = [make_linear(prev_layer[j], curr_layer[j]) for j in range(n_subnetworks)]
+                else:
+                    # Previous is shared: all subnetworks take the same input
+                    sub_layers = [make_linear(prev_layer, curr_layer[j]) for j in range(n_subnetworks)]
+                self.layers.append(torch.nn.ModuleList(sub_layers))
+            else:
+                # Shared layer
+                if isinstance(prev_layer, (list, tuple)):
+                    # Previous is parallel: its outputs are concatenated
+                    input_size = sum(prev_layer)
+                else:
+                    input_size = prev_layer
+                self.layers.append(make_linear(input_size, curr_layer))
+
+        # Process the output layer
+        prev_output_layer = layer_sizes[-2]
+        output_layer = layer_sizes[-1]
+
+        if isinstance(output_layer, (list, tuple)):
+            if isinstance(prev_output_layer, (list, tuple)):
+                # Each subnetwork takes the corresponding previous sub-layer output
+                output_layers = [make_linear(prev_output_layer[j], output_layer[j]) for j in range(n_subnetworks)]
+            else:
+                # All subnetworks take the same shared input
+                output_layers = [make_linear(prev_output_layer, output_layer[j]) for j in range(n_subnetworks)]
+            self.layers.append(torch.nn.ModuleList(output_layers))
         else:
-            self.layers.append(make_linear(layer_sizes[-2], n_output))
+            if isinstance(prev_output_layer, (list, tuple)):
+                # Each subnetwork outputs size 1; the outputs are concatenated to the final output size
+                output_layers = [make_linear(prev_output_layer[j], 1) for j in range(n_subnetworks)]
+                self.layers.append(torch.nn.ModuleList(output_layers))
+            else:
+                # Shared output layer
+                self.layers.append(make_linear(prev_output_layer, output_layer))

     def forward(self, inputs):
         x = inputs
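
As a sanity check on the construction logic above, the sketch below prints the shape of each weight created for one nested specification. It assumes the class is importable as `deepxde.nn.pytorch.fnn.PFNN` with the PyTorch backend selected (e.g. `DDE_BACKEND=pytorch`), and it is not part of the commit.

```python
import torch
from deepxde.nn.pytorch.fnn import PFNN

# Expected structure for [3, 64, [32, 32], [16, 16], 2]:
#   shared Linear(3, 64)
#   shared 64 -> parallel [32, 32]: ModuleList of two Linear(64, 32)
#   parallel [32, 32] -> [16, 16]:  ModuleList of two Linear(32, 16)
#   output int 2 after a list:      ModuleList of two Linear(16, 1), concatenated
net = PFNN([3, 64, [32, 32], [16, 16], 2], "tanh", "Glorot normal")
for layer in net.layers:
    if isinstance(layer, torch.nn.ModuleList):
        print([tuple(sub.weight.shape) for sub in layer])
    else:
        print(tuple(layer.weight.shape))
```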
@@ -137,14 +160,18 @@ def forward(self, inputs):

         for layer in self.layers[:-1]:
             if isinstance(layer, torch.nn.ModuleList):
+                # Parallel layer processing
                 if isinstance(x, list):
                     x = [self.activation(f(x_)) for f, x_ in zip(layer, x)]
                 else:
                     x = [self.activation(f(x)) for f in layer]
             else:
+                # Shared layer processing (concatenate if necessary)
+                if isinstance(x, list):
+                    x = torch.cat(x, dim=1)
                 x = self.activation(layer(x))

-        # output layers
+        # Output layer processing
         if isinstance(x, list):
             x = torch.cat([f(x_) for f, x_ in zip(self.layers[-1], x)], dim=1)
         else:
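
Finally, a quick shape check of the forward pass, exercising the re-join path this change adds (a parallel list layer followed by a shared int layer, which now concatenates the subnetwork outputs). Same assumptions as the sketch above; the batch size and widths are arbitrary.

```python
import torch
from deepxde.nn.pytorch.fnn import PFNN

# [3, [32, 32], 16, [8, 8], 2]: the shared 16-unit layer receives the
# concatenation of the two 32-unit subnetwork outputs (64 features).
net = PFNN([3, [32, 32], 16, [8, 8], 2], "tanh", "Glorot normal")
x = torch.rand(5, 3)  # assumes DeepXDE's default float32 precision
y = net(x)
print(y.shape)  # expected: torch.Size([5, 2])
```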