
Commit b79d2fd

JAX BACKEND: FNN supports regularization (#1968)
Parent: 80b8cf1

File tree (4 files changed: 26 additions, 2 deletions)

deepxde/backend/jax/tensor.py
deepxde/model.py
deepxde/nn/jax/fnn.py
deepxde/nn/jax/nn.py

deepxde/backend/jax/tensor.py

Lines changed: 12 additions & 0 deletions
@@ -175,3 +175,15 @@ def zeros(shape, dtype):
 
 def zeros_like(input_tensor):
     return jnp.zeros_like(input_tensor)
+
+
+def l1_regularization(l1):
+    return lambda params: l1 * jnp.sum(jnp.concatenate([jnp.abs(w).flatten() for w in params]))
+
+
+def l2_regularization(l2):
+    return lambda params: l2 * jnp.sum(jnp.concatenate([jnp.square(w).flatten() for w in params]))
+
+
+def l1_l2_regularization(l1, l2):
+    return lambda params: l1_regularization(l1)(params) + l2_regularization(l2)(params)
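
A minimal standalone sketch of how these helpers behave (not part of the commit; the toy weight list and factor below are made up): each function returns a closure that flattens every array in a list of parameter leaves and sums the absolute values or squares, scaled by the given factor.

import jax.numpy as jnp

def l2_regularization(l2):
    # Same definition as in the diff above: scaled sum of squared entries over all leaves.
    return lambda params: l2 * jnp.sum(jnp.concatenate([jnp.square(w).flatten() for w in params]))

# Toy "leaves" of a parameter pytree (hypothetical shapes).
weights = [jnp.ones((3, 2)), jnp.ones((2,))]
penalty = l2_regularization(1e-4)(weights)  # 1e-4 * (6 + 2) = 8e-4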

deepxde/model.py

Lines changed: 3 additions & 1 deletion
@@ -437,12 +437,14 @@ def outputs_fn(inputs):
             # We use aux so that self.data.losses is a pure function.
             aux = [outputs_fn, ext_params] if ext_params else [outputs_fn]
             losses = losses_fn(targets, outputs_, loss_fn, inputs, self, aux=aux)
-            # TODO: Add regularization loss
             if not isinstance(losses, list):
                 losses = [losses]
             losses = jax.numpy.asarray(losses)
             if self.loss_weights is not None:
                 losses *= jax.numpy.asarray(self.loss_weights)
+            if self.net.regularizer is not None:
+                regul_loss = self.net.regularizer(jax.tree.leaves(nn_params["params"]))
+                losses = jax.numpy.concatenate([losses, regul_loss.reshape(1)])
             return outputs_, losses
 
         @jax.jit
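
A minimal sketch of the concatenation step above (standalone, with made-up loss values): the regularization penalty is a scalar, so it is reshaped to a length-1 array and appended to the per-term loss vector as one extra loss component.

import jax.numpy as jnp

losses = jnp.asarray([0.12, 0.03])  # e.g. PDE residual loss and BC loss (made-up values)
regul_loss = jnp.asarray(5e-4)      # scalar penalty returned by net.regularizer
losses = jnp.concatenate([losses, regul_loss.reshape(1)])
# losses is now [0.12, 0.03, 0.0005]; the penalty appears as its own loss term.

Note that in the diff the penalty is appended after self.loss_weights has been applied, so the regularization term itself is not scaled by the loss weights.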

deepxde/nn/jax/fnn.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@ class FNN(NN):
     layer_sizes: Any
     activation: Any
     kernel_initializer: Any
+    regularization: Any = None
 
     params: Any = None
     _input_transform: Callable = None
@@ -78,6 +79,7 @@ class PFNN(NN):
     layer_sizes: Any
     activation: Any
     kernel_initializer: Any
+    regularization: Any = None
 
     params: Any = None
     _input_transform: Callable = None
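
A hypothetical construction showing the new field; the ["l2", 1e-4] identifier format is an assumption mirroring DeepXDE's other backends and is not confirmed by this diff. Since regularization defaults to None, existing models are unaffected.

import deepxde as dde  # run with the JAX backend selected, e.g. DDE_BACKEND=jax

net = dde.nn.FNN(
    layer_sizes=[2, 32, 32, 1],
    activation="tanh",
    kernel_initializer="Glorot normal",
    regularization=["l2", 1e-4],  # assumed identifier format; None (the default) disables the penalty
)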

deepxde/nn/jax/nn.py

Lines changed: 9 additions & 1 deletion
@@ -1,14 +1,22 @@
 from flax import linen as nn
 
+from .. import regularizers
+
 
 class NN(nn.Module):
     """Base class for all neural network modules."""
 
     # All sub-modules should have the following variables:
+    # regularization: Any = None
     # params: Any = None
     # _input_transform: Optional[Callable] = None
     # _output_transform: Optional[Callable] = None
-
+
+    @property
+    def regularizer(self):
+        """Dynamically compute and return the regularizer function based on regularization."""
+        return regularizers.get(self.regularization)
+
     def apply_feature_transform(self, transform):
         """Compute the features by appling a transform to the network inputs, i.e.,
         features = transform(inputs). Then, outputs = network(features).
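
The property keeps the stored identifier as plain configuration on the frozen flax dataclass and only resolves it to a callable when it is accessed. A minimal standalone sketch of that pattern (the _get_regularizer helper and the ["l2", value] identifier are illustrative assumptions; the real dispatcher is deepxde.nn.regularizers.get, whose exact identifier format is not shown in this diff):

from typing import Any

import jax.numpy as jnp
from flax import linen as nn


def _get_regularizer(identifier):
    # Hypothetical dispatcher standing in for deepxde.nn.regularizers.get.
    if identifier is None:
        return None
    name, factor = identifier  # assumed ["l2", value]-style identifier
    if name == "l2":
        return lambda params: factor * sum(jnp.sum(jnp.square(w)) for w in params)
    raise ValueError(f"Unknown regularizer: {name}")


class TinyNet(nn.Module):
    regularization: Any = None

    @property
    def regularizer(self):
        # Resolved lazily, so the dataclass field stays a plain, comparable identifier.
        return _get_regularizer(self.regularization)

    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)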
