Skip to content

Commit 3b0825c

Browse files
angel-core authored and The tunix Authors committed
Add Pathways support and proper testing.
PiperOrigin-RevId: 914897467
1 parent 26e5db2 commit 3b0825c

2 files changed

Lines changed: 44 additions & 1 deletion

File tree

tests/sft/checkpoint_manager_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,34 @@ def setUp(self):
101101
axis_names=('fsdp', 'tp'),
102102
)
103103

104+
def test_handlers_options(self):
  """Checks handler storage flags against the configured JAX platforms.

  Creates a `CheckpointManager` rooted at a per-test directory, collects the
  handlers registered for the 'model_params' and 'optimizer_state' items, and
  asserts on their private `_use_ocdbt` / `_use_zarr3` flags:

  * When `jax.config.jax_platforms` contains 'proxy' or 'pathways', OCDBT is
    expected to be off for both items; otherwise it is expected to be on.
  * Zarr3 is expected to be off for both items in either case.
  """
  manager = checkpoint_manager.CheckpointManager(
      f'{self.temp_path}/{self.id()}'
  )

  # `jax_platforms` may be unset (None); treat that as "no platform named".
  active_platforms = jax.config.jax_platforms or ''
  expect_ocdbt = not (
      'proxy' in active_platforms or 'pathways' in active_platforms
  )

  # The registry is only reachable through private attributes, hence the
  # pytype suppressions on these lookups.
  root_handler = manager._checkpoint_manager._checkpointer._handler  # pytype: disable=attribute-error
  entries = root_handler._handler_registry.get_all_entries()  # pytype: disable=attribute-error

  # Registry keys are 2-tuples whose first element is the item name; the
  # second element is not needed here.
  by_item = {item_name: h for (item_name, _), h in entries.items()}

  tracked_items = ('model_params', 'optimizer_state')
  for item in tracked_items:
    self.assertIn(item, by_item)

  check_ocdbt = self.assertTrue if expect_ocdbt else self.assertFalse
  for item in tracked_items:
    check_ocdbt(by_item[item]._use_ocdbt)  # pytype: disable=attribute-error

  for item in tracked_items:
    self.assertFalse(by_item[item]._use_zarr3)  # pytype: disable=attribute-error
131+
104132
def test_empty_root_directory(self):
105133
cp_manager = checkpoint_manager.CheckpointManager(root_directory=None)
106134
self.assertIsNone(cp_manager.latest_step())
@@ -299,6 +327,12 @@ def test_restore_with_backward_compatibility(self, ckpt_path):
299327
# The checkpoints in test_data is saved with StandardSave. The test is to
300328
# verify the checkpoint manager with PyTreeRestore can still restore the
301329
# checkpoints saved with StandardSave.
330+
if os.getenv('ENABLE_PATHWAYS_PERSISTENCE', '') == '1':
331+
self.skipTest(
332+
'Pathways persistence cannot read standard backwards-compatible'
333+
' checkpoints.'
334+
)
335+
302336
ckpt_manager = checkpoint_manager.CheckpointManager(
303337
os.path.join(os.path.dirname(__file__), ckpt_path)
304338
)

tunix/sft/checkpoint_manager.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ def __init__(
5050
if root_directory is not None:
5151
# When using Pathways, the checkpoint manager only supports persistence
5252
# APIs now.
53-
if 'proxy' in os.getenv('JAX_PLATFORMS', ''):
53+
platforms = jax.config.jax_platforms or ''
54+
if (
55+
'proxy' in platforms
56+
or 'pathways' in platforms
57+
):
5458
item_handlers = {
5559
'model_params': ocp.PyTreeCheckpointHandler(
5660
use_ocdbt=False,
@@ -65,6 +69,11 @@ def __init__(
6569
logging.info(
6670
'Using persistence API for checkpointing with Pathways.'
6771
)
72+
ocp.pathways.register_type_handlers(
73+
checkpointing_impl=ocp.pathways.CheckpointingImpl.from_options(
74+
use_remote_python=True,
75+
)
76+
)
6877
else:
6978
logging.warning(
7079
'Checkpointing without the persistence API, be aware of potential'

0 commit comments

Comments
 (0)