140 changes: 93 additions & 47 deletions deepxde/nn/pytorch/fnn.py
@@ -60,8 +60,14 @@ class PFNN(NN):
(how the layers are connected). If `layer_sizes[i]` is an int, it represents
one layer shared by all the outputs; if `layer_sizes[i]` is a list, it
represents `len(layer_sizes[i])` sub-layers, each of which is exclusively
used by one output. Every list in `layer_sizes` must have the same length
(= number of subnetworks). If the last element of `layer_sizes` is an int
preceded by a list, it must equal the number of subnetworks: each
subnetwork then has an output size of 1, and the outputs are concatenated.
If the last element is a list, it specifies the output size of each
subnetwork before concatenation.
activation: Activation function.
kernel_initializer: Initializer for the kernel weights.
"""

def __init__(self, layer_sizes, activation, kernel_initializer):
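A usage sketch of the rules above (added for illustration, not part of the diff; the import path and the "tanh" / "Glorot uniform" strings follow common DeepXDE usage and are assumptions here):

from deepxde.nn.pytorch.fnn import PFNN

# 2-dimensional input, two parallel hidden layers of three 8-neuron sub-layers,
# and a trailing int equal to the number of subnetworks: each subnetwork outputs
# one value and the three values are concatenated.
net_a = PFNN([2, [8, 8, 8], [8, 8, 8], 3], "tanh", "Glorot uniform")

# Trailing list instead: per-subnetwork output sizes 1, 2 and 1, so the
# concatenated network output has size 4.
net_b = PFNN([2, 64, [8, 8, 8], [1, 2, 1]], "tanh", "Glorot uniform")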
@@ -74,10 +80,31 @@ def __init__(self, layer_sizes, activation, kernel_initializer):
raise ValueError("must specify input and output sizes")
if not isinstance(layer_sizes[0], int):
raise ValueError("input size must be integer")

# Determine the number of subnetworks from the first list layer
list_layers = [
layer for layer in layer_sizes if isinstance(layer, (list, tuple))
]
if not list_layers:
raise ValueError(
"No list layers found; use FNN instead of PFNN for single subnetwork."
)
n_subnetworks = len(list_layers[0])
for layer in list_layers:
if len(layer) != n_subnetworks:
raise ValueError(
"All list layers must have the same length as the first list layer."
)

# Validate output layer if preceded by a list layer
if (
isinstance(layer_sizes[-1], int)
and isinstance(layer_sizes[-2], (list, tuple))
and layer_sizes[-1] != n_subnetworks
):
raise ValueError(
"If last layer is an int and previous is a list, the int must equal the number of subnetworks."
)
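# Illustration (added note): with layer_sizes = [2, [8, 8], [8, 8], 2] the two list
# layers give n_subnetworks == 2 and the trailing int 2 passes this check;
# [2, [8, 8], [8, 8], 3] would raise the ValueError above, and [2, 32, 32, 3]
# (no list layers at all) is rejected earlier in favor of plain FNN.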

def make_linear(n_input, n_output):
linear = torch.nn.Linear(n_input, n_output, dtype=config.real(torch))
@@ -86,49 +113,64 @@ def make_linear(n_input, n_output):
return linear

self.layers = torch.nn.ModuleList()

# Process hidden layers (excluding the output layer)
for i in range(1, len(layer_sizes) - 1):
prev_layer = layer_sizes[i - 1]
curr_layer = layer_sizes[i]

if isinstance(curr_layer, (list, tuple)):
# Parallel layer
if isinstance(prev_layer, (list, tuple)):
# Previous layer is parallel: each sub-layer takes the corresponding previous sub-layer's output
sub_layers = [
make_linear(prev_layer[j], curr_layer[j])
for j in range(n_subnetworks)
]
else:
# Previous is shared: all subnetworks take the same input
sub_layers = [
make_linear(prev_layer, curr_layer[j])
for j in range(n_subnetworks)
]
self.layers.append(torch.nn.ModuleList(sub_layers))
else:
# Shared layer
if isinstance(prev_layer, (list, tuple)):
# Previous is parallel: concatenate outputs
input_size = sum(prev_layer)
else:
input_size = prev_layer
self.layers.append(make_linear(input_size, curr_layer))
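# Example (added note): for layer_sizes = [2, 64, [32, 32, 32], 3] this loop builds
# a shared Linear(2, 64) followed by a ModuleList of three parallel Linear(64, 32)
# sub-layers; if a shared int layer followed the parallel one, its input size would
# be the concatenated width sum([32, 32, 32]) = 96.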

# Process output layer
prev_output_layer = layer_sizes[-2]
output_layer = layer_sizes[-1]

if isinstance(output_layer, (list, tuple)):
if isinstance(prev_output_layer, (list, tuple)):
# Each sub-layer takes the corresponding previous sub-layer's output
output_layers = [
make_linear(prev_output_layer[j], output_layer[j])
for j in range(n_subnetworks)
]
else:
# All subnetworks take the same shared input
output_layers = [
make_linear(prev_output_layer, output_layer[j])
for j in range(n_subnetworks)
]
self.layers.append(torch.nn.ModuleList(output_layers))
else:
if isinstance(prev_output_layer, (list, tuple)):
# Each subnetwork outputs a single value; concatenation yields output_layer (= n_subnetworks) values
output_layers = [
make_linear(prev_output_layer[j], 1) for j in range(n_subnetworks)
]
self.layers.append(torch.nn.ModuleList(output_layers))
else:
# Shared output layer
self.layers.append(make_linear(prev_output_layer, output_layer))
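# Example (added note): continuing layer_sizes = [2, 64, [32, 32, 32], 3], the output
# step appends a ModuleList of three Linear(32, 1) sub-layers, so self.layers is
# [Linear(2, 64), ModuleList(3 x Linear(64, 32)), ModuleList(3 x Linear(32, 1))]
# and forward() concatenates the three single outputs into a size-3 output.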

def forward(self, inputs):
x = inputs
@@ -137,14 +179,18 @@ def forward(self, inputs):

for layer in self.layers[:-1]:
if isinstance(layer, torch.nn.ModuleList):
# Parallel layer processing
if isinstance(x, list):
x = [self.activation(f(x_)) for f, x_ in zip(layer, x)]
else:
x = [self.activation(f(x)) for f in layer]
else:
# Shared layer processing (concatenate if necessary)
if isinstance(x, list):
x = torch.cat(x, dim=1)
x = self.activation(layer(x))
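# After this loop, x is a single tensor if the last hidden layer was shared,
# or a list of per-subnetwork tensors if it was parallel.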

# Output layer processing
if isinstance(x, list):
x = torch.cat([f(x_) for f, x_ in zip(self.layers[-1], x)], dim=1)
else:
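End-to-end sketch (added for illustration; the activation and initializer strings are assumed, and dtypes follow the default dde.config real precision):

import torch
from deepxde.nn.pytorch.fnn import PFNN

net = PFNN([2, 64, [32, 32, 32], 3], "tanh", "Glorot uniform")
x = torch.randn(16, 2)  # batch of 16 two-dimensional inputs
y = net(x)
print(y.shape)  # expected: torch.Size([16, 3]), one column per subnetwork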