@@ -101,6 +101,34 @@ def setUp(self):
101101 axis_names = ('fsdp' , 'tp' ),
102102 )
103103
104+ def test_handlers_options (self ):
105+ """Verifies OCDBT/Zarr3 options match active platform configuration."""
106+ cp_path = f'{ self .temp_path } /{ self .id ()} '
107+ cp_manager = checkpoint_manager .CheckpointManager (cp_path )
108+
109+ platforms = jax .config .jax_platforms or ''
110+ is_pathways_or_proxy = 'proxy' in platforms or 'pathways' in platforms
111+
112+ handler = cp_manager ._checkpoint_manager ._checkpointer ._handler # pytype: disable=attribute-error
113+ registry_entries = handler ._handler_registry .get_all_entries () # pytype: disable=attribute-error
114+
115+ handlers = {}
116+ for (item_name , _ ), h in registry_entries .items ():
117+ handlers [item_name ] = h
118+
119+ self .assertIn ('model_params' , handlers )
120+ self .assertIn ('optimizer_state' , handlers )
121+
122+ if is_pathways_or_proxy :
123+ self .assertFalse (handlers ['model_params' ]._use_ocdbt ) # pytype: disable=attribute-error
124+ self .assertFalse (handlers ['optimizer_state' ]._use_ocdbt ) # pytype: disable=attribute-error
125+ else :
126+ self .assertTrue (handlers ['model_params' ]._use_ocdbt ) # pytype: disable=attribute-error
127+ self .assertTrue (handlers ['optimizer_state' ]._use_ocdbt ) # pytype: disable=attribute-error
128+
129+ self .assertFalse (handlers ['model_params' ]._use_zarr3 ) # pytype: disable=attribute-error
130+ self .assertFalse (handlers ['optimizer_state' ]._use_zarr3 ) # pytype: disable=attribute-error
131+
104132 def test_empty_root_directory (self ):
105133 cp_manager = checkpoint_manager .CheckpointManager (root_directory = None )
106134 self .assertIsNone (cp_manager .latest_step ())
@@ -299,6 +327,12 @@ def test_restore_with_backward_compatibility(self, ckpt_path):
299327 # The checkpoints in test_data is saved with StandardSave. The test is to
300328 # verify the checkpoint manager with PyTreeRestore can still restore the
301329 # checkpoints saved with StandardSave.
330+ if os .getenv ('ENABLE_PATHWAYS_PERSISTENCE' , '' ) == '1' :
331+ self .skipTest (
332+ 'Pathways persistence cannot read standard backwards-compatible'
333+ ' checkpoints.'
334+ )
335+
302336 ckpt_manager = checkpoint_manager .CheckpointManager (
303337 os .path .join (os .path .dirname (__file__ ), ckpt_path )
304338 )
0 commit comments