diff --git a/deepxde/model.py b/deepxde/model.py index 3584b76c3..e6acba9c9 100644 --- a/deepxde/model.py +++ b/deepxde/model.py @@ -4,6 +4,7 @@ from collections import OrderedDict import numpy as np +import orbax.checkpoint as ocp from . import config from . import display @@ -1011,6 +1012,21 @@ def save(self, save_path, protocol="backend", verbose=0): elif backend_name == "tensorflow": save_path += ".ckpt" self.net.save_weights(save_path) + elif backend_name == "jax": + # TODO: identify a better solution that complies with PEP8 + from flax.training import orbax_utils + save_path += ".ckpt" + checkpoint = { + "params": self.params, + "state": self.opt_state + } + self.checkpointer = ocp.PyTreeCheckpointer() + save_args = orbax_utils.save_args_from_target(checkpoint) + # `Force=True` option causes existing checkpoints to be + # overwritten, matching the PyTorch checkpointer behaviour. + self.checkpointer.save( + save_path, checkpoint, force=True, save_args=save_args + ) elif backend_name == "pytorch": save_path += ".pt" checkpoint = { @@ -1055,6 +1071,10 @@ def restore(self, save_path, device=None, verbose=0): self.saver.restore(self.sess, save_path) elif backend_name == "tensorflow": self.net.load_weights(save_path) + elif backend_name == "jax": + checkpoint = self.checkpointer.restore(save_path) + self.params, self.opt_state = checkpoint["params"], checkpoint["state"] + self.net.params, self.external_trainable_variables = self.params elif backend_name == "pytorch": if device is not None: checkpoint = torch.load(save_path, map_location=torch.device(device)) diff --git a/docs/requirements.txt b/docs/requirements.txt index 3f7b1e10f..64dda108e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,6 @@ matplotlib numpy +orbax-checkpoint scikit-learn scikit-optimize>=0.9.0 scipy diff --git a/pyproject.toml b/pyproject.toml index a1f7c55bd..ea3918577 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ classifiers = [ dependencies = [ "matplotlib", "numpy", + "orbax-checkpoint", "scikit-learn", "scikit-optimize>=0.9.0", "scipy", diff --git a/requirements.txt b/requirements.txt index 38608b103..2ea9e4f28 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ matplotlib numpy +orbax-checkpoint scikit-learn scikit-optimize>=0.9.0 scipy