Skip to content

Commit e42a6b7

Browse files
committed
move regularizer to base nn class
1 parent 12710f0 commit e42a6b7

File tree

2 files changed

+7
-14
lines changed

2 files changed

+7
-14
lines changed

deepxde/nn/jax/fnn.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from .nn import NN
88
from .. import activations
99
from .. import initializers
10-
from .. import regularizers
11-
1210

1311
class FNN(NN):
1412
"""Fully-connected neural network."""
@@ -22,11 +20,6 @@ class FNN(NN):
2220
_input_transform: Callable = None
2321
_output_transform: Callable = None
2422

25-
@property
26-
def regularizer(self):
27-
"""Dynamically compute and return the regularizer function based on regularization."""
28-
return regularizers.get(self.regularization)
29-
3023
def setup(self):
3124
# TODO: implement get regularizer
3225
if isinstance(self.activation, list):
@@ -91,11 +84,6 @@ class PFNN(NN):
9184
_input_transform: Callable = None
9285
_output_transform: Callable = None
9386

94-
@property
95-
def regularizer(self):
96-
"""Dynamically compute and return the regularizer function based on regularization."""
97-
return regularizers.get(self.regularization)
98-
9987
def setup(self):
10088
if len(self.layer_sizes) <= 1:
10189
raise ValueError("must specify input and output sizes")

deepxde/nn/jax/nn.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from flax import linen as nn
2-
2+
from .. import regularizers
33

44
class NN(nn.Module):
55
"""Base class for all neural network modules."""
@@ -8,7 +8,12 @@ class NN(nn.Module):
88
# params: Any = None
99
# _input_transform: Optional[Callable] = None
1010
# _output_transform: Optional[Callable] = None
11-
11+
12+
@property
13+
def regularizer(self):
14+
"""Dynamically compute and return the regularizer function based on regularization."""
15+
return regularizers.get(self.regularization)
16+
1217
def apply_feature_transform(self, transform):
1318
"""Compute the features by appling a transform to the network inputs, i.e.,
1419
features = transform(inputs). Then, outputs = network(features).

0 commit comments

Comments
 (0)