
Commit 03d5dde

format with black
1 parent d6abac9 commit 03d5dde

File tree

1 file changed (+31, −9)


deepxde/nn/pytorch/fnn.py

Lines changed: 31 additions & 9 deletions
@@ -70,7 +70,9 @@ class PFNN(NN):
         kernel_initializer: Initializer for the kernel weights.
     """
 
-    def __init__(self, layer_sizes, activation, kernel_initializer, regularization=None):
+    def __init__(
+        self, layer_sizes, activation, kernel_initializer, regularization=None
+    ):
         super().__init__()
         self.activation = activations.get(activation)
         initializer = initializers.get(kernel_initializer)
@@ -83,13 +85,19 @@ def __init__(self, layer_sizes, activation, kernel_initializer, regularization=N
             raise ValueError("input size must be integer")
 
         # Determine the number of subnetworks from the first list layer
-        list_layers = [layer for layer in layer_sizes if isinstance(layer, (list, tuple))]
+        list_layers = [
+            layer for layer in layer_sizes if isinstance(layer, (list, tuple))
+        ]
         if not list_layers:
-            raise ValueError("No list layers found; use FNN instead of PFNN for single subnetwork.")
+            raise ValueError(
+                "No list layers found; use FNN instead of PFNN for single subnetwork."
+            )
         n_subnetworks = len(list_layers[0])
         for layer in list_layers:
             if len(layer) != n_subnetworks:
-                raise ValueError("All list layers must have the same length as the first list layer.")
+                raise ValueError(
+                    "All list layers must have the same length as the first list layer."
+                )
 
         # Validate output layer if preceded by a list layer
         if (
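For context, the validation in this hunk requires layer_sizes to contain at least one list entry, and every list entry to have the same length (one element per parallel subnetwork). A minimal sketch of values that pass and fail, assuming the PFNN constructor in this file; the variable names are illustrative only:

# Hypothetical layer_sizes values for the checks above (not part of the commit).
ok = [2, [32, 32, 32], [32, 32, 32], 3]    # three parallel subnetworks
no_lists = [2, 32, 32, 3]                  # no list layers -> ValueError, use FNN
mismatch = [2, [32, 32], [32, 32, 32], 3]  # lengths 2 vs 3 -> ValueError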
@@ -118,10 +126,16 @@ def make_linear(n_input, n_output):
                 # Parallel layer
                 if isinstance(prev_layer, (list, tuple)):
                     # Previous is parallel: each subnetwork input is previous subnetwork output
-                    sub_layers = [make_linear(prev_layer[j], curr_layer[j]) for j in range(n_subnetworks)]
+                    sub_layers = [
+                        make_linear(prev_layer[j], curr_layer[j])
+                        for j in range(n_subnetworks)
+                    ]
                 else:
                     # Previous is shared: all subnetworks take the same input
-                    sub_layers = [make_linear(prev_layer, curr_layer[j]) for j in range(n_subnetworks)]
+                    sub_layers = [
+                        make_linear(prev_layer, curr_layer[j])
+                        for j in range(n_subnetworks)
+                    ]
                 self.layers.append(torch.nn.ModuleList(sub_layers))
             else:
                 # Shared layer
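The two branches in this hunk differ only in each subnetwork's fan-in: after a parallel layer, subnetwork j reads its own previous output prev_layer[j]; after a shared layer, every subnetwork reads the full shared width. A minimal standalone sketch of the shared-input branch, assuming only torch and a make_linear helper like the one defined earlier in this file:

import torch

def make_linear(n_input, n_output):
    # Stand-in for this file's helper, which also applies the kernel initializer.
    return torch.nn.Linear(n_input, n_output)

prev_layer = 32              # shared previous layer of width 32
curr_layer = [16, 16, 16]    # parallel layer: three subnetworks of width 16
sub_layers = [
    make_linear(prev_layer, curr_layer[j])  # same fan-in for every subnetwork
    for j in range(len(curr_layer))
]
print(torch.nn.ModuleList(sub_layers))  # three independent Linear(32, 16) modules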
@@ -139,15 +153,23 @@ def make_linear(n_input, n_output):
         if isinstance(output_layer, (list, tuple)):
             if isinstance(prev_output_layer, (list, tuple)):
                 # Each subnetwork input is corresponding previous output
-                output_layers = [make_linear(prev_output_layer[j], output_layer[j]) for j in range(n_subnetworks)]
+                output_layers = [
+                    make_linear(prev_output_layer[j], output_layer[j])
+                    for j in range(n_subnetworks)
+                ]
             else:
                 # All subnetworks take the same shared input
-                output_layers = [make_linear(prev_output_layer, output_layer[j]) for j in range(n_subnetworks)]
+                output_layers = [
+                    make_linear(prev_output_layer, output_layer[j])
+                    for j in range(n_subnetworks)
+                ]
             self.layers.append(torch.nn.ModuleList(output_layers))
         else:
             if isinstance(prev_output_layer, (list, tuple)):
                 # Each subnetwork outputs 1 and concatenates to output_layer size
-                output_layers = [make_linear(prev_output_layer[j], 1) for j in range(n_subnetworks)]
+                output_layers = [
+                    make_linear(prev_output_layer[j], 1) for j in range(n_subnetworks)
+                ]
                 self.layers.append(torch.nn.ModuleList(output_layers))
             else:
                 # Shared output layer
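All four hunks are behavior-preserving: black only splits the long signature, comprehensions, and raise statements at its default 88-character line limit. A minimal usage sketch of the class being reformatted, assuming DeepXDE's PyTorch backend is active; the widths and batch size are illustrative:

import torch
from deepxde.nn.pytorch.fnn import PFNN

# Two inputs feed two parallel width-20 subnetworks; with an integer output
# layer after a list layer, each subnetwork contributes one output column.
net = PFNN([2, [20, 20], [20, 20], 2], "tanh", "Glorot uniform")
x = torch.rand(8, 2)  # batch of 8 samples, 2 input features
y = net(x)            # expected shape: (8, 2)
print(y.shape)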

0 commit comments