Commit 458d880

Added JAX checkpointing via Orbax
1 parent 0b518c6 commit 458d880

1 file changed: +21 −0 lines changed

deepxde/model.py

Lines changed: 21 additions & 0 deletions
@@ -1011,6 +1011,23 @@ def save(self, save_path, protocol="backend", verbose=0):
             elif backend_name == "tensorflow":
                 save_path += ".ckpt"
                 self.net.save_weights(save_path)
+            elif backend_name == "jax":
+                # Lazily load Orbax to avoid a hard dependency when not using JAX
+                # TODO: identify a better solution that complies with PEP8
+                import orbax.checkpoint as ocp
+                from flax.training import orbax_utils
+                save_path += ".ckpt"
+                checkpoint = {
+                    "params": self.params,
+                    "state": self.opt_state,
+                }
+                self.checkpointer = ocp.PyTreeCheckpointer()
+                save_args = orbax_utils.save_args_from_target(checkpoint)
+                # The `force=True` option causes existing checkpoints to be
+                # overwritten, matching the PyTorch checkpointer behaviour.
+                self.checkpointer.save(
+                    save_path, checkpoint, force=True, save_args=save_args
+                )
             elif backend_name == "pytorch":
                 save_path += ".pt"
                 checkpoint = {
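
For reference, a minimal standalone sketch of the Orbax save pattern added above; the example pytrees and the /tmp path stand in for self.params, self.opt_state, and the real save path, and are not part of this commit:

# Standalone sketch of the save pattern from the hunk above.
# The pytrees and the checkpoint path are made up for illustration.
import jax.numpy as jnp
import orbax.checkpoint as ocp
from flax.training import orbax_utils

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}  # stand-in for self.params
opt_state = {"count": jnp.array(0)}                  # stand-in for self.opt_state
checkpoint = {"params": params, "state": opt_state}

checkpointer = ocp.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(checkpoint)
# force=True overwrites any existing checkpoint at this path, mirroring the
# behaviour of the PyTorch branch below.
checkpointer.save("/tmp/model-1000.ckpt", checkpoint, force=True, save_args=save_args)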
@@ -1055,6 +1072,10 @@ def restore(self, save_path, device=None, verbose=0):
             self.saver.restore(self.sess, save_path)
         elif backend_name == "tensorflow":
             self.net.load_weights(save_path)
+        elif backend_name == "jax":
+            checkpoint = self.checkpointer.restore(save_path)
+            self.params, self.opt_state = checkpoint["params"], checkpoint["state"]
+            self.net.params, self.external_trainable_variables = self.params
         elif backend_name == "pytorch":
             if device is not None:
                 checkpoint = torch.load(save_path, map_location=torch.device(device))
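
The matching restore side as a standalone sketch, reusing the checkpoint written by the save sketch above. The final unpacking in Model.restore assumes, as the JAX backend does elsewhere, that self.params bundles the network parameters together with the external trainable variables:

# Standalone sketch of the restore pattern from the hunk above.
# Run the save sketch first so the checkpoint exists.
import orbax.checkpoint as ocp

checkpointer = ocp.PyTreeCheckpointer()
checkpoint = checkpointer.restore("/tmp/model-1000.ckpt")
params, opt_state = checkpoint["params"], checkpoint["state"]
# Model.restore then hands the pieces back to the model:
#     self.net.params, self.external_trainable_variables = self.params
print(params["w"].shape)  # (2, 2) if the save sketch was run first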
