diff --git a/deepxde/backend/jax/tensor.py b/deepxde/backend/jax/tensor.py
index 61bf56e9f..4daedc401 100644
--- a/deepxde/backend/jax/tensor.py
+++ b/deepxde/backend/jax/tensor.py
@@ -175,3 +175,15 @@ def zeros(shape, dtype):
 
 def zeros_like(input_tensor):
     return jnp.zeros_like(input_tensor)
+
+
+def l1_regularization(l1):
+    return lambda params: l1 * jnp.sum(jnp.concatenate([jnp.abs(w).flatten() for w in params]))
+
+
+def l2_regularization(l2):
+    return lambda params: l2 * jnp.sum(jnp.concatenate([jnp.square(w).flatten() for w in params]))
+
+
+def l1_l2_regularization(l1, l2):
+    return lambda params: l1_regularization(l1)(params) + l2_regularization(l2)(params)
diff --git a/deepxde/model.py b/deepxde/model.py
index 0a939dad3..af0bcfa91 100644
--- a/deepxde/model.py
+++ b/deepxde/model.py
@@ -437,12 +437,14 @@ def outputs_fn(inputs):
             # We use aux so that self.data.losses is a pure function.
             aux = [outputs_fn, ext_params] if ext_params else [outputs_fn]
             losses = losses_fn(targets, outputs_, loss_fn, inputs, self, aux=aux)
-            # TODO: Add regularization loss
             if not isinstance(losses, list):
                 losses = [losses]
             losses = jax.numpy.asarray(losses)
             if self.loss_weights is not None:
                 losses *= jax.numpy.asarray(self.loss_weights)
+            if self.net.regularizer is not None:
+                regul_loss = self.net.regularizer(jax.tree.leaves(nn_params["params"]))
+                losses = jax.numpy.concatenate([losses, regul_loss.reshape(1)])
             return outputs_, losses
 
         @jax.jit
diff --git a/deepxde/nn/jax/fnn.py b/deepxde/nn/jax/fnn.py
index 61ba7a670..20e4afbd6 100644
--- a/deepxde/nn/jax/fnn.py
+++ b/deepxde/nn/jax/fnn.py
@@ -15,6 +15,7 @@ class FNN(NN):
     layer_sizes: Any
     activation: Any
     kernel_initializer: Any
+    regularization: Any = None
 
     params: Any = None
     _input_transform: Callable = None
@@ -78,6 +79,7 @@ class PFNN(NN):
     layer_sizes: Any
     activation: Any
     kernel_initializer: Any
+    regularization: Any = None
 
     params: Any = None
     _input_transform: Callable = None
diff --git a/deepxde/nn/jax/nn.py b/deepxde/nn/jax/nn.py
index 200ea353d..ea241a3c7 100644
--- a/deepxde/nn/jax/nn.py
+++ b/deepxde/nn/jax/nn.py
@@ -1,14 +1,22 @@
 from flax import linen as nn
 
+from .. import regularizers
+
 
 class NN(nn.Module):
     """Base class for all neural network modules."""
 
     # All sub-modules should have the following variables:
+    # regularization: Any = None
     # params: Any = None
     # _input_transform: Optional[Callable] = None
     # _output_transform: Optional[Callable] = None
-
+
+    @property
+    def regularizer(self):
+        """Dynamically compute and return the regularizer function based on regularization."""
+        return regularizers.get(self.regularization)
+
     def apply_feature_transform(self, transform):
         """Compute the features by appling a transform to the network inputs,
         i.e., features = transform(inputs). Then, outputs = network(features).
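
Note (not part of the patch): a minimal standalone sketch of how the new backend helpers
behave on a plain list of parameter arrays. The array shapes and the 0.01 weight below
are made up for illustration.

    import jax.numpy as jnp

    def l1_regularization(l1):
        # Same definition as in the patch: sum of the absolute values of
        # every parameter entry, scaled by the factor l1.
        return lambda params: l1 * jnp.sum(
            jnp.concatenate([jnp.abs(w).flatten() for w in params])
        )

    params = [jnp.ones((2, 3)), jnp.full((3,), -2.0)]
    # 6 entries of |1| plus 3 entries of |-2| -> 0.01 * (6 + 6) = 0.12
    print(l1_regularization(0.01)(params))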
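
For context on the fnn.py/nn.py changes: with the new regularization field, a network
under the JAX backend would be constructed as below. This is a hypothetical usage sketch;
it assumes regularizers.get accepts the ["name", weight] identifier convention that
DeepXDE uses for its other backends.

    import deepxde as dde

    # L2 penalty with weight 1e-4; jax.tree.leaves flattens every parameter
    # leaf, so kernels and biases alike are penalized.
    net = dde.nn.FNN(
        [1] + [32] * 2 + [1],
        "tanh",
        "Glorot normal",
        regularization=["l2", 1e-4],
    )

One consequence of the model.py change: the regularization term is appended as an extra
entry of the loss vector via jax.numpy.concatenate, so the reported loss list gains one
component when a regularizer is set, and self.loss_weights (applied before the
concatenation) does not scale it.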