Skip to content

Commit b27d057

Browse files
authored
Backend jax supports initializers, improve FNN (#547)
1 parent 8761595 commit b27d057

File tree

7 files changed

+41
-21
lines changed

7 files changed

+41
-21
lines changed

deepxde/model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,6 @@ def closure():
269269

270270
def _compile_jax(self, lr, loss_fn, decay, loss_weights):
271271
"""jax"""
272-
import optax
273-
274272
# initialize network's parameters
275273
# TODO: Init should move to network module, because we don't know how to init here, e.g., DeepONet has two inputs.
276274
# random seed should use a random number, or be specified by users
@@ -303,7 +301,7 @@ def loss_function(params):
303301
) # jax.value_and_grad seems to be slightly faster than jax.grad for function approximation
304302
grads = grad_fn(params)
305303
updates, new_opt_state = self.opt.update(grads, opt_state)
306-
new_params = optax.apply_updates(params, updates)
304+
new_params = optimizers.apply_updates(params, updates)
307305
return new_params, new_opt_state
308306

309307
def outputs(training, inputs):

deepxde/nn/initializers.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import math
44

55
from .. import config
6-
from ..backend import backend_name, tf, torch
6+
from ..backend import backend_name, tf, torch, jax
77

88

99
class VarianceScalingStacked:
@@ -139,10 +139,24 @@ def initializer_dict_torch():
139139
}
140140

141141

142+
def initializer_dict_jax():
143+
return {
144+
"Glorot normal": jax.nn.initializers.glorot_normal(),
145+
"Glorot uniform": jax.nn.initializers.glorot_uniform(),
146+
"He normal": jax.nn.initializers.he_normal(),
147+
"He uniform": jax.nn.initializers.he_uniform(),
148+
"Lecun normal": jax.nn.initializers.lecun_normal(),
149+
"Lecun uniform": jax.nn.initializers.lecun_uniform(),
150+
"zeros": jax.nn.initializers.zeros,
151+
}
152+
153+
142154
if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
143155
INITIALIZER_DICT = initializer_dict_tf()
144156
elif backend_name == "pytorch":
145157
INITIALIZER_DICT = initializer_dict_torch()
158+
elif backend_name == "jax":
159+
INITIALIZER_DICT = initializer_dict_jax()
146160

147161

148162
def get(identifier):

deepxde/nn/jax/fnn.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
1-
from typing import Any
1+
from typing import Any, Callable
22

33
import jax
44
from flax import linen as nn
55

66
from .nn import NN
7+
from .. import activations
8+
from .. import initializers
79

810

911
class FNN(NN):
1012
"""Fully-connected neural network"""
1113

12-
layer_sizes: Any = None
13-
activation: Any = None
14-
kernel_initializer: Any = None
14+
layer_sizes: Any
15+
activation: Any
16+
kernel_initializer: Any
17+
training: bool = True
18+
_input_transform: Callable = None
19+
_output_transform: Callable = None
20+
params: Any = None
1521

1622
def setup(self):
17-
# TODO: implement get activation, get initializer
18-
self._activation = jax.nn.tanh
19-
kernel_initializer = jax.nn.initializers.glorot_normal()
23+
# TODO: implement get regularizer
24+
self._activation = activations.get(self.activation)
25+
kernel_initializer = initializers.get(self.kernel_initializer)
2026
initializer = jax.nn.initializers.zeros
2127

2228
self.denses = [

deepxde/nn/jax/nn.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
from typing import Any
2-
31
from flax import linen as nn
42

53

64
class NN(nn.Module):
75
"""Base class for all neural network modules."""
86

9-
training: Any = True
10-
params: Any = None
11-
_input_transform: Any = None
12-
_output_transform: Any = None
7+
# all sub-modules should have the following init-only variables:
8+
# training: bool = True
9+
# params: Any = None
10+
# _input_transform: Optional[Callable] = None
11+
# _output_transform: Optional[Callable] = None
1312

1413
def apply_feature_transform(self, transform):
1514
"""Compute the features by appling a transform to the network inputs, i.e.,

deepxde/optimizers/jax/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__all__ = ["get", "is_external_optimizer"]
1+
__all__ = ["get", "is_external_optimizer", "apply_updates"]
22

3-
from .optimizers import get, is_external_optimizer
3+
from .optimizers import get, is_external_optimizer, apply_updates

deepxde/optimizers/jax/optimizers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
__all__ = ["get", "is_external_optimizer"]
1+
__all__ = ["get", "is_external_optimizer", "apply_updates"]
22

33
import jax
44
import optax
55

66

7+
apply_updates = optax.apply_updates
8+
9+
710
def is_external_optimizer(optimizer):
811
# TODO: add external optimizers
912
return False

examples/function/func.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch"""
1+
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, jax"""
22
import deepxde as dde
33
import numpy as np
44

0 commit comments

Comments
 (0)