Commit d6abac9

add pfnn features
1 parent bb7ddb4 commit d6abac9

1 file changed: +75 -48 lines changed

deepxde/nn/pytorch/fnn.py

Lines changed: 75 additions & 48 deletions
@@ -60,24 +60,46 @@ class PFNN(NN):
             (how the layers are connected). If `layer_sizes[i]` is an int, it represents
             one layer shared by all the outputs; if `layer_sizes[i]` is a list, it
             represents `len(layer_sizes[i])` sub-layers, each of which is exclusively
-            used by one output. Note that `len(layer_sizes[i])` should equal the number
-            of outputs. Every number specifies the number of neurons in that layer.
+            used by one output. Every list in `layer_sizes` must have the same length
+            (= number of subnetworks). If the last element of `layer_sizes` is an int
+            preceded by a list, it must be equal to the number of subnetworks: all
+            subnetworks have an output size of 1 and are then concatenated. If the last
+            element is a list, it specifies the output size for each subnetwork before
+            concatenation.
+        activation: Activation function.
+        kernel_initializer: Initializer for the kernel weights.
     """
 
-    def __init__(self, layer_sizes, activation, kernel_initializer):
+    def __init__(self, layer_sizes, activation, kernel_initializer, regularization=None):
         super().__init__()
         self.activation = activations.get(activation)
         initializer = initializers.get(kernel_initializer)
         initializer_zero = initializers.get("zeros")
+        self.regularizer = regularization
 
         if len(layer_sizes) <= 1:
             raise ValueError("must specify input and output sizes")
         if not isinstance(layer_sizes[0], int):
             raise ValueError("input size must be integer")
-        if not isinstance(layer_sizes[-1], int):
-            raise ValueError("output size must be integer")
 
-        n_output = layer_sizes[-1]
+        # Determine the number of subnetworks from the first list layer
+        list_layers = [layer for layer in layer_sizes if isinstance(layer, (list, tuple))]
+        if not list_layers:
+            raise ValueError("No list layers found; use FNN instead of PFNN for single subnetwork.")
+        n_subnetworks = len(list_layers[0])
+        for layer in list_layers:
+            if len(layer) != n_subnetworks:
+                raise ValueError("All list layers must have the same length as the first list layer.")
+
+        # Validate output layer if preceded by a list layer
+        if (
+            isinstance(layer_sizes[-1], int)
+            and isinstance(layer_sizes[-2], (list, tuple))
+            and layer_sizes[-1] != n_subnetworks
+        ):
+            raise ValueError(
+                "If last layer is an int and previous is a list, the int must equal the number of subnetworks."
+            )
 
         def make_linear(n_input, n_output):
            linear = torch.nn.Linear(n_input, n_output, dtype=config.real(torch))
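
A minimal usage sketch of the `layer_sizes` forms described by the updated docstring and enforced by the new validation above. The concrete sizes and the "tanh" / "Glorot normal" strings are illustrative choices, and the snippet assumes DeepXDE is configured with the PyTorch backend.

import deepxde as dde

# Two subnetworks; the last element is an int equal to the number of subnetworks,
# so each subnetwork outputs size 1 and the results are concatenated -> 2 outputs.
net_a = dde.nn.PFNN([2, [32, 32], [32, 32], 2], "tanh", "Glorot normal")

# Two subnetworks; the last element is a list of per-subnetwork output sizes,
# concatenated to 1 + 2 = 3 outputs.
net_b = dde.nn.PFNN([2, [32, 32], [1, 2]], "tanh", "Glorot normal")

# Rejected by the new checks: list layers of different lengths.
# dde.nn.PFNN([2, [32, 32], [32, 32, 32], 2], "tanh", "Glorot normal")  # raises ValueError
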
@@ -86,49 +108,50 @@ def make_linear(n_input, n_output):
             return linear
 
         self.layers = torch.nn.ModuleList()
+
+        # Process hidden layers (excluding the output layer)
         for i in range(1, len(layer_sizes) - 1):
-            prev_layer_size = layer_sizes[i - 1]
-            curr_layer_size = layer_sizes[i]
-            if isinstance(curr_layer_size, (list, tuple)):
-                if len(curr_layer_size) != n_output:
-                    raise ValueError(
-                        "number of sub-layers should equal number of network outputs"
-                    )
-                if isinstance(prev_layer_size, (list, tuple)):
-                    # e.g. [8, 8, 8] -> [16, 16, 16]
-                    self.layers.append(
-                        torch.nn.ModuleList(
-                            [
-                                make_linear(prev_layer_size[j], curr_layer_size[j])
-                                for j in range(n_output)
-                            ]
-                        )
-                    )
-                else:  # e.g. 64 -> [8, 8, 8]
-                    self.layers.append(
-                        torch.nn.ModuleList(
-                            [
-                                make_linear(prev_layer_size, curr_layer_size[j])
-                                for j in range(n_output)
-                            ]
-                        )
-                    )
-            else:  # e.g. 64 -> 64
-                if not isinstance(prev_layer_size, int):
-                    raise ValueError(
-                        "cannot rejoin parallel subnetworks after splitting"
-                    )
-                self.layers.append(make_linear(prev_layer_size, curr_layer_size))
-
-        # output layers
-        if isinstance(layer_sizes[-2], (list, tuple)):  # e.g. [3, 3, 3] -> 3
-            self.layers.append(
-                torch.nn.ModuleList(
-                    [make_linear(layer_sizes[-2][j], 1) for j in range(n_output)]
-                )
-            )
+            prev_layer = layer_sizes[i - 1]
+            curr_layer = layer_sizes[i]
+
+            if isinstance(curr_layer, (list, tuple)):
+                # Parallel layer
+                if isinstance(prev_layer, (list, tuple)):
+                    # Previous is parallel: each subnetwork input is previous subnetwork output
+                    sub_layers = [make_linear(prev_layer[j], curr_layer[j]) for j in range(n_subnetworks)]
+                else:
+                    # Previous is shared: all subnetworks take the same input
+                    sub_layers = [make_linear(prev_layer, curr_layer[j]) for j in range(n_subnetworks)]
+                self.layers.append(torch.nn.ModuleList(sub_layers))
+            else:
+                # Shared layer
+                if isinstance(prev_layer, (list, tuple)):
+                    # Previous is parallel: concatenate outputs
+                    input_size = sum(prev_layer)
+                else:
+                    input_size = prev_layer
+                self.layers.append(make_linear(input_size, curr_layer))
+
+        # Process output layer
+        prev_output_layer = layer_sizes[-2]
+        output_layer = layer_sizes[-1]
+
+        if isinstance(output_layer, (list, tuple)):
+            if isinstance(prev_output_layer, (list, tuple)):
+                # Each subnetwork input is corresponding previous output
+                output_layers = [make_linear(prev_output_layer[j], output_layer[j]) for j in range(n_subnetworks)]
+            else:
+                # All subnetworks take the same shared input
+                output_layers = [make_linear(prev_output_layer, output_layer[j]) for j in range(n_subnetworks)]
+            self.layers.append(torch.nn.ModuleList(output_layers))
         else:
-            self.layers.append(make_linear(layer_sizes[-2], n_output))
+            if isinstance(prev_output_layer, (list, tuple)):
+                # Each subnetwork outputs 1 and concatenates to output_layer size
+                output_layers = [make_linear(prev_output_layer[j], 1) for j in range(n_subnetworks)]
+                self.layers.append(torch.nn.ModuleList(output_layers))
+            else:
+                # Shared output layer
+                self.layers.append(make_linear(prev_output_layer, output_layer))
 
     def forward(self, inputs):
         x = inputs
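
A sketch of the module structure the rewritten constructor above builds for a concrete `layer_sizes`, including a shared layer placed after a parallel one, which the removed code rejected with "cannot rejoin parallel subnetworks after splitting". The printed structure is paraphrased, and the snippet assumes the PyTorch backend with the default float precision.

import deepxde as dde

net = dde.nn.PFNN([2, 16, [8, 8], 2], "tanh", "Glorot normal")
print(net.layers)
# Roughly:
# ModuleList(
#   (0): Linear(2 -> 16)                   # shared hidden layer
#   (1): ModuleList(Linear(16 -> 8) x 2)   # parallel hidden layer, one branch per subnetwork
#   (2): ModuleList(Linear(8 -> 1) x 2)    # per-subnetwork outputs, concatenated to 2 outputs
# )

# Rejoining is now allowed: the shared 16-neuron layer takes the
# concatenated 8 + 8 = 16 features of the parallel layer before it.
rejoined = dde.nn.PFNN([2, [8, 8], 16, [8, 8], 2], "tanh", "Glorot normal")
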
@@ -137,14 +160,18 @@ def forward(self, inputs):
 
         for layer in self.layers[:-1]:
             if isinstance(layer, torch.nn.ModuleList):
+                # Parallel layer processing
                 if isinstance(x, list):
                     x = [self.activation(f(x_)) for f, x_ in zip(layer, x)]
                 else:
                     x = [self.activation(f(x)) for f in layer]
             else:
+                # Shared layer processing (concatenate if necessary)
+                if isinstance(x, list):
+                    x = torch.cat(x, dim=1)
                 x = self.activation(layer(x))
 
-        # output layers
+        # Output layer processing
         if isinstance(x, list):
             x = torch.cat([f(x_) for f, x_ in zip(self.layers[-1], x)], dim=1)
         else:
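
A quick shape check for the updated forward pass above; a sketch that assumes the PyTorch backend, with a random input used only for illustration.

import torch
import deepxde as dde

net = dde.nn.PFNN([2, [16, 16], [1, 2]], "tanh", "Glorot normal")
x = torch.rand(5, 2)  # 5 sample points with 2 input features
y = net(x)
print(y.shape)  # expected torch.Size([5, 3]): per-subnetwork outputs concatenated along dim 1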
