Merged
12 changes: 12 additions & 0 deletions deepxde/backend/jax/tensor.py
@@ -175,3 +175,15 @@ def zeros(shape, dtype):
 
 def zeros_like(input_tensor):
     return jnp.zeros_like(input_tensor)
+
+
+def l1_regularization(l1):
+    return lambda params: l1 * jnp.sum(jnp.concatenate([jnp.abs(w).flatten() for w in params]))
+
+
+def l2_regularization(l2):
+    return lambda params: l2 * jnp.sum(jnp.concatenate([jnp.square(w).flatten() for w in params]))
+
+
+def l1_l2_regularization(l1, l2):
+    return lambda params: l1_regularization(l1)(params) + l2_regularization(l2)(params)
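For reference, this is how the new helpers behave on a flat list of parameter arrays. A minimal sketch with toy values; nothing below comes from DeepXDE itself:

import jax.numpy as jnp

from deepxde.backend.jax.tensor import l1_regularization, l2_regularization

# Two toy "weight" arrays standing in for flattened network parameters.
params = [jnp.array([[1.0, -2.0], [3.0, 0.5]]), jnp.array([0.25, -0.75])]

l1_fn = l1_regularization(0.01)  # penalty = 0.01 * sum(|w|) over all entries
l2_fn = l2_regularization(0.01)  # penalty = 0.01 * sum(w**2) over all entries

print(l1_fn(params))  # 0.01 * 7.5 = 0.075
print(l2_fn(params))  # 0.01 * 14.875 = 0.14875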
4 changes: 3 additions & 1 deletion deepxde/model.py
@@ -437,12 +437,14 @@ def outputs_fn(inputs):
             # We use aux so that self.data.losses is a pure function.
             aux = [outputs_fn, ext_params] if ext_params else [outputs_fn]
             losses = losses_fn(targets, outputs_, loss_fn, inputs, self, aux=aux)
-            # TODO: Add regularization loss
             if not isinstance(losses, list):
                 losses = [losses]
             losses = jax.numpy.asarray(losses)
             if self.loss_weights is not None:
                 losses *= jax.numpy.asarray(self.loss_weights)
+            if self.net.regularizer is not None:
+                regul_loss = self.net.regularizer(jax.tree.leaves(nn_params["params"]))
+                losses = jax.numpy.concatenate([losses, regul_loss.reshape(1)])
             return outputs_, losses
 
         @jax.jit
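Two observations for reviewers. First, the penalty is appended as one extra entry of the losses array, so loss_train gains an extra column whenever regularization is set. Second, loss_weights is applied before the concatenation, so the penalty itself is never rescaled by loss_weights. The snippet below is an illustrative sketch of what jax.tree.leaves produces for a Flax params dict; nn.Dense and the shapes are examples only:

import jax
import jax.numpy as jnp
from flax import linen as nn

# jax.tree.leaves flattens the nested Flax params dict into the flat list of
# arrays that the l1/l2 regularization lambdas expect.
variables = nn.Dense(3).init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
leaves = jax.tree.leaves(variables["params"])
print([w.shape for w in leaves])  # [(3,), (2, 3)]: bias first, then kernel

Note that the leaves include biases, so biases are penalized along with the kernels.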
2 changes: 2 additions & 0 deletions deepxde/nn/jax/fnn.py
@@ -15,6 +15,7 @@ class FNN(NN):
     layer_sizes: Any
     activation: Any
     kernel_initializer: Any
+    regularization: Any = None
 
     params: Any = None
     _input_transform: Callable = None
@@ -78,6 +79,7 @@ class PFNN(NN):
     layer_sizes: Any
     activation: Any
     kernel_initializer: Any
+    regularization: Any = None
 
     params: Any = None
     _input_transform: Callable = None
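With the new attribute in place, constructing a regularized network under the JAX backend should look roughly like this. The ("l2", 1e-4) identifier follows the (name, scale) convention DeepXDE uses for other backends, so treat the exact format as an assumption:

import deepxde as dde  # run with DDE_BACKEND=jax

# regularization defaults to None, so existing code is unaffected.
net = dde.nn.FNN(
    [2] + [32] * 3 + [1],  # layer_sizes
    "tanh",  # activation
    "Glorot uniform",  # kernel_initializer
    regularization=("l2", 1e-4),
)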
10 changes: 9 additions & 1 deletion deepxde/nn/jax/nn.py
@@ -1,14 +1,22 @@
 from flax import linen as nn
 
+from .. import regularizers
+
 
 class NN(nn.Module):
     """Base class for all neural network modules."""
 
     # All sub-modules should have the following variables:
+    #     regularization: Any = None
     #     params: Any = None
     #     _input_transform: Optional[Callable] = None
     #     _output_transform: Optional[Callable] = None
 
+    @property
+    def regularizer(self):
+        """Dynamically compute and return the regularizer function based on regularization."""
+        return regularizers.get(self.regularization)
+
     def apply_feature_transform(self, transform):
         """Compute the features by applying a transform to the network inputs, i.e.,
         features = transform(inputs). Then, outputs = network(features).
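Putting the pieces together: regularizers.get is not part of this diff, so the dispatch below is a hypothetical sketch of what the property resolves to, based on the backend functions added in tensor.py. The identifier names ("l1", "l2", "l1l2") are assumptions:

from deepxde.backend.jax.tensor import (
    l1_regularization,
    l1_l2_regularization,
    l2_regularization,
)

def get(identifier):
    # Hypothetical dispatch, NOT the actual deepxde/nn/regularizers.py code.
    # identifier is e.g. ("l1", 0.01), ("l2", 0.01), or ("l1l2", 0.01, 0.01).
    if not identifier:
        return None
    name, scales = identifier[0].lower(), identifier[1:]
    if name == "l1":
        return l1_regularization(scales[0])
    if name == "l2":
        return l2_regularization(scales[0])
    if name == "l1l2":
        return l1_l2_regularization(scales[0], scales[1])
    raise ValueError(f"Unknown regularization: {name}")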