Commit f7ae7e0

Backend PyTorch: Add PFNN features (#1966)
1 parent e99a95b commit f7ae7e0

File tree

1 file changed (+93 -47 lines changed)

deepxde/nn/pytorch/fnn.py

Lines changed: 93 additions & 47 deletions
@@ -60,8 +60,14 @@ class PFNN(NN):
             (how the layers are connected). If `layer_sizes[i]` is an int, it represents
             one layer shared by all the outputs; if `layer_sizes[i]` is a list, it
             represents `len(layer_sizes[i])` sub-layers, each of which is exclusively
-            used by one output. Note that `len(layer_sizes[i])` should equal the number
-            of outputs. Every number specifies the number of neurons in that layer.
+            used by one output. Every list in `layer_sizes` must have the same length
+            (= number of subnetworks). If the last element of `layer_sizes` is an int
+            preceded by a list, it must be equal to the number of subnetworks: all
+            subnetworks have an output size of 1 and are then concatenated. If the last
+            element is a list, it specifies the output size for each subnetwork before
+            concatenation.
+        activation: Activation function.
+        kernel_initializer: Initializer for the kernel weights.
     """
 
     def __init__(self, layer_sizes, activation, kernel_initializer):
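As a quick illustration of the conventions documented above, here is a minimal usage sketch (not part of the commit; it assumes the PyTorch backend is selected, e.g. via DDE_BACKEND=pytorch, and the layer widths are arbitrary examples):

import deepxde as dde  # assumes the PyTorch backend, so this backend's PFNN is used

# Trailing int equal to the number of subnetworks: each of the two subnetworks
# (hidden widths 32, 32) emits one value; the two values are concatenated.
net_a = dde.nn.PFNN([2, [32, 32], [32, 32], 2], "tanh", "Glorot normal")

# Trailing list: each subnetwork gets its own output size, so the concatenated
# output here has 1 + 2 = 3 components.
net_b = dde.nn.PFNN([2, [32, 32], [32, 32], [1, 2]], "tanh", "Glorot normal")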
@@ -74,10 +80,31 @@ def __init__(self, layer_sizes, activation, kernel_initializer):
             raise ValueError("must specify input and output sizes")
         if not isinstance(layer_sizes[0], int):
             raise ValueError("input size must be integer")
-        if not isinstance(layer_sizes[-1], int):
-            raise ValueError("output size must be integer")
 
-        n_output = layer_sizes[-1]
+        # Determine the number of subnetworks from the first list layer
+        list_layers = [
+            layer for layer in layer_sizes if isinstance(layer, (list, tuple))
+        ]
+        if not list_layers:
+            raise ValueError(
+                "No list layers found; use FNN instead of PFNN for single subnetwork."
+            )
+        n_subnetworks = len(list_layers[0])
+        for layer in list_layers:
+            if len(layer) != n_subnetworks:
+                raise ValueError(
+                    "All list layers must have the same length as the first list layer."
+                )
+
+        # Validate output layer if preceded by a list layer
+        if (
+            isinstance(layer_sizes[-1], int)
+            and isinstance(layer_sizes[-2], (list, tuple))
+            and layer_sizes[-1] != n_subnetworks
+        ):
+            raise ValueError(
+                "If last layer is an int and previous is a list, the int must equal the number of subnetworks."
+            )
 
         def make_linear(n_input, n_output):
             linear = torch.nn.Linear(n_input, n_output, dtype=config.real(torch))
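The checks above can be read as a small standalone rule set. The following sketch mirrors them outside the class (illustrative only; check_layer_sizes is not a deepxde function):

def check_layer_sizes(layer_sizes):
    # Mirror of the validation added above (illustrative, not deepxde API).
    list_layers = [l for l in layer_sizes if isinstance(l, (list, tuple))]
    if not list_layers:
        raise ValueError("no list layers; use FNN instead of PFNN")
    n_subnetworks = len(list_layers[0])
    if any(len(l) != n_subnetworks for l in list_layers):
        raise ValueError("all list layers must have the same length")
    if (
        isinstance(layer_sizes[-1], int)
        and isinstance(layer_sizes[-2], (list, tuple))
        and layer_sizes[-1] != n_subnetworks
    ):
        raise ValueError("last int must equal the number of subnetworks")
    return n_subnetworks

print(check_layer_sizes([2, [32, 32], [32, 32], 2]))        # 2 subnetworks
print(check_layer_sizes([2, 64, [32, 32, 32], [8, 8, 8]]))  # 3 subnetworks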
@@ -86,49 +113,64 @@ def make_linear(n_input, n_output):
             return linear
 
         self.layers = torch.nn.ModuleList()
+
+        # Process hidden layers (excluding the output layer)
         for i in range(1, len(layer_sizes) - 1):
-            prev_layer_size = layer_sizes[i - 1]
-            curr_layer_size = layer_sizes[i]
-            if isinstance(curr_layer_size, (list, tuple)):
-                if len(curr_layer_size) != n_output:
-                    raise ValueError(
-                        "number of sub-layers should equal number of network outputs"
-                    )
-                if isinstance(prev_layer_size, (list, tuple)):
-                    # e.g. [8, 8, 8] -> [16, 16, 16]
-                    self.layers.append(
-                        torch.nn.ModuleList(
-                            [
-                                make_linear(prev_layer_size[j], curr_layer_size[j])
-                                for j in range(n_output)
-                            ]
-                        )
-                    )
-                else:  # e.g. 64 -> [8, 8, 8]
-                    self.layers.append(
-                        torch.nn.ModuleList(
-                            [
-                                make_linear(prev_layer_size, curr_layer_size[j])
-                                for j in range(n_output)
-                            ]
-                        )
-                    )
-            else:  # e.g. 64 -> 64
-                if not isinstance(prev_layer_size, int):
-                    raise ValueError(
-                        "cannot rejoin parallel subnetworks after splitting"
-                    )
-                self.layers.append(make_linear(prev_layer_size, curr_layer_size))
-
-        # output layers
-        if isinstance(layer_sizes[-2], (list, tuple)):  # e.g. [3, 3, 3] -> 3
-            self.layers.append(
-                torch.nn.ModuleList(
-                    [make_linear(layer_sizes[-2][j], 1) for j in range(n_output)]
-                )
-            )
+            prev_layer = layer_sizes[i - 1]
+            curr_layer = layer_sizes[i]
+
+            if isinstance(curr_layer, (list, tuple)):
+                # Parallel layer
+                if isinstance(prev_layer, (list, tuple)):
+                    # Previous is parallel: each subnetwork input is previous subnetwork output
+                    sub_layers = [
+                        make_linear(prev_layer[j], curr_layer[j])
+                        for j in range(n_subnetworks)
+                    ]
+                else:
+                    # Previous is shared: all subnetworks take the same input
+                    sub_layers = [
+                        make_linear(prev_layer, curr_layer[j])
+                        for j in range(n_subnetworks)
+                    ]
+                self.layers.append(torch.nn.ModuleList(sub_layers))
+            else:
+                # Shared layer
+                if isinstance(prev_layer, (list, tuple)):
+                    # Previous is parallel: concatenate outputs
+                    input_size = sum(prev_layer)
+                else:
+                    input_size = prev_layer
+                self.layers.append(make_linear(input_size, curr_layer))
+
+        # Process output layer
+        prev_output_layer = layer_sizes[-2]
+        output_layer = layer_sizes[-1]
+
+        if isinstance(output_layer, (list, tuple)):
+            if isinstance(prev_output_layer, (list, tuple)):
+                # Each subnetwork input is corresponding previous output
+                output_layers = [
+                    make_linear(prev_output_layer[j], output_layer[j])
+                    for j in range(n_subnetworks)
+                ]
+            else:
+                # All subnetworks take the same shared input
+                output_layers = [
+                    make_linear(prev_output_layer, output_layer[j])
+                    for j in range(n_subnetworks)
+                ]
+            self.layers.append(torch.nn.ModuleList(output_layers))
         else:
-            self.layers.append(make_linear(layer_sizes[-2], n_output))
+            if isinstance(prev_output_layer, (list, tuple)):
+                # Each subnetwork outputs 1 and concatenates to output_layer size
+                output_layers = [
+                    make_linear(prev_output_layer[j], 1) for j in range(n_subnetworks)
+                ]
+                self.layers.append(torch.nn.ModuleList(output_layers))
+            else:
+                # Shared output layer
+                self.layers.append(make_linear(prev_output_layer, output_layer))
 
     def forward(self, inputs):
         x = inputs
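Tracing the constructor above for a concrete case, e.g. layer_sizes = [2, 64, [32, 32], 2], gives the module structure sketched below in plain PyTorch (inferred from the code; the real PFNN also applies the chosen kernel initializer and dtype, omitted here):

import torch

layers = torch.nn.ModuleList([
    torch.nn.Linear(2, 64),  # shared hidden layer: 2 -> 64
    torch.nn.ModuleList(     # parallel layer: the shared 64 feeds both subnetworks
        [torch.nn.Linear(64, 32), torch.nn.Linear(64, 32)]
    ),
    torch.nn.ModuleList(     # output: trailing int 2 -> each subnetwork emits size 1
        [torch.nn.Linear(32, 1), torch.nn.Linear(32, 1)]
    ),
])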
@@ -137,14 +179,18 @@ def forward(self, inputs):
 
         for layer in self.layers[:-1]:
             if isinstance(layer, torch.nn.ModuleList):
+                # Parallel layer processing
                 if isinstance(x, list):
                     x = [self.activation(f(x_)) for f, x_ in zip(layer, x)]
                 else:
                     x = [self.activation(f(x)) for f in layer]
             else:
+                # Shared layer processing (concatenate if necessary)
+                if isinstance(x, list):
+                    x = torch.cat(x, dim=1)
                 x = self.activation(layer(x))
 
-        # output layers
+        # Output layer processing
         if isinstance(x, list):
             x = torch.cat([f(x_) for f, x_ in zip(self.layers[-1], x)], dim=1)
         else:
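To make the branching in forward() concrete, here is a self-contained shape-flow sketch for layer_sizes = [2, 64, [32, 32], 2] with tanh activation (illustrative only; the real forward also applies optional input/output transforms, omitted here):

import torch

shared = torch.nn.Linear(2, 64)
split = [torch.nn.Linear(64, 32), torch.nn.Linear(64, 32)]
heads = [torch.nn.Linear(32, 1), torch.nn.Linear(32, 1)]

x = torch.randn(5, 2)                                       # batch of 5 samples
x = torch.tanh(shared(x))                                   # shared layer -> (5, 64)
x = [torch.tanh(f(x)) for f in split]                       # parallel -> two (5, 32) tensors
x = torch.cat([f(x_) for f, x_ in zip(heads, x)], dim=1)    # output heads -> (5, 2)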
