@@ -1011,6 +1011,23 @@ def save(self, save_path, protocol="backend", verbose=0):
10111011 elif backend_name == "tensorflow" :
10121012 save_path += ".ckpt"
10131013 self .net .save_weights (save_path )
1014+ elif backend_name == "jax" :
1015+ # Lazy load Orbax to avoid a hard dependancy when using JAX
1016+ # TODO: identify a better solution that complies with PEP8
1017+ import orbax .checkpoint as ocp
1018+ from flax .training import orbax_utils
1019+ save_path += ".ckpt"
1020+ checkpoint = {
1021+ "params" : self .params ,
1022+ "state" : self .opt_state
1023+ }
1024+ self .checkpointer = ocp .PyTreeCheckpointer ()
1025+ save_args = orbax_utils .save_args_from_target (checkpoint )
1026+ # `Force=True` option causes existing checkpoints to be
1027+ # overwritten, matching the PyTorch checkpointer behaviour.
1028+ self .checkpointer .save (
1029+ save_path , checkpoint , force = True , save_args = save_args
1030+ )
10141031 elif backend_name == "pytorch" :
10151032 save_path += ".pt"
10161033 checkpoint = {
@@ -1055,6 +1072,10 @@ def restore(self, save_path, device=None, verbose=0):
10551072 self .saver .restore (self .sess , save_path )
10561073 elif backend_name == "tensorflow" :
10571074 self .net .load_weights (save_path )
1075+ elif backend_name == "jax" :
1076+ checkpoint = self .checkpointer .restore (save_path )
1077+ self .params , self .opt_state = checkpoint ["params" ], checkpoint ["state" ]
1078+ self .net .params , self .external_trainable_variables = self .params
10581079 elif backend_name == "pytorch" :
10591080 if device is not None :
10601081 checkpoint = torch .load (save_path , map_location = torch .device (device ))
0 commit comments