Skip to content

Commit 1a35f53

Browse files
angel-core and The tunix Authors
authored and committed
Add Pathways support and proper testing.
PiperOrigin-RevId: 914897467
1 parent 26e5db2 commit 1a35f53

3 files changed

Lines changed: 132 additions & 1 deletion

File tree

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Pathways tests for CheckpointManager."""
2+
3+
import os
4+
from unittest import mock
5+
6+
import jax
7+
import orbax.checkpoint as ocp
8+
from tunix.sft import checkpoint_manager
9+
from tunix.sft import checkpoint_manager_test
10+
11+
from GOOGLE_INTERNAL_PACKAGE_PATH.pyglib.contrib.g3_multiprocessing import g3_multiprocessing
12+
from GOOGLE_INTERNAL_PACKAGE_PATH.testing.pybase import googletest
13+
14+
15+
# Reference to the real `register_type_handlers`, captured at import time.
# The tests below patch `ocp.pathways.register_type_handlers`, so
# `wrapped_register` needs this handle to delegate to the unpatched function.
_ORIG_REGISTER = ocp.pathways.register_type_handlers
16+
17+
18+
# Required for Pathways for testing.
19+
def wrapped_register(*args, **kwargs):
  """Calls the original `register_type_handlers` with `thinmint_testing=True`.

  Positional and keyword arguments are forwarded unchanged, except that any
  caller-supplied `thinmint_testing` value is overridden with True.

  Returns:
    Whatever the original `register_type_handlers` returns.
  """
  forced_kwargs = {**kwargs, 'thinmint_testing': True}
  return _ORIG_REGISTER(*args, **forced_kwargs)
22+
23+
24+
class DirectPathwaysPersistenceTest(
    checkpoint_manager_test.BaseCheckpointManagerTest
):
  """CheckpointManager tests run with ENABLE_PATHWAYS_PERSISTENCE set to '1'.

  Inherits the test cases of `BaseCheckpointManagerTest`; the sharding test
  is overridden to skip, and a registration check is added.
  """

  def setUp(self):
    super().setUp()
    env_patch = mock.patch.dict(
        os.environ, {'ENABLE_PATHWAYS_PERSISTENCE': '1'}
    )
    self.enterContext(env_patch)
    # Intercept calls in production code to ensure thinmint_testing stays True
    register_patch = mock.patch.object(
        ocp.pathways,
        'register_type_handlers',
        side_effect=wrapped_register,
    )
    self.mock_register = self.enterContext(register_patch)

  def test_restore_different_sharding(self):
    self.skipTest('Pathways only supports jax.sharding.NamedSharding.')

  def test_register_type_handlers(self):
    """Constructing a manager registers handlers exactly once, remote-python."""
    # Keep a reference so the manager lives for the duration of the test.
    _manager = checkpoint_manager.CheckpointManager(
        f'{self.temp_path}/{self.id()}'
    )

    self.mock_register.assert_called_once()
    actual_impl = self.mock_register.call_args.kwargs['checkpointing_impl']
    expected_impl = ocp.pathways.CheckpointingImpl.from_options(
        use_remote_python=True
    )
    self.assertEqual(actual_impl, expected_impl)
58+
59+
60+
class DirectPathwaysNoPersistenceTest(
    checkpoint_manager_test.BaseCheckpointManagerTest
):
  """CheckpointManager tests run with ENABLE_PATHWAYS_PERSISTENCE set to ''.

  Inherits the test cases of `BaseCheckpointManagerTest`; the sharding test
  is overridden to skip, and a check that no registration happens is added.
  """

  def setUp(self):
    super().setUp()
    env_patch = mock.patch.dict(
        os.environ, {'ENABLE_PATHWAYS_PERSISTENCE': ''}
    )
    self.enterContext(env_patch)
    # Still route through `wrapped_register` so that an unexpected call would
    # reach the original function with thinmint_testing forced to True.
    register_patch = mock.patch.object(
        ocp.pathways,
        'register_type_handlers',
        side_effect=wrapped_register,
    )
    self.mock_register = self.enterContext(register_patch)

  def test_restore_different_sharding(self):
    self.skipTest('Pathways only supports jax.sharding.NamedSharding.')

  def test_register_type_handlers(self):
    """Constructing a manager must not register Pathways type handlers."""
    # Keep a reference so the manager lives for the duration of the test.
    _manager = checkpoint_manager.CheckpointManager(
        f'{self.temp_path}/{self.id()}'
    )

    self.mock_register.assert_not_called()
85+
86+
if __name__ == '__main__':
  # Parse JAX's absl flags before handing control to the test runner.
  jax.config.parse_flags_with_absl()
  # NOTE(review): handle_test_main is an internal helper; presumably it wraps
  # googletest.main with multiprocessing setup -- confirm against its docs.
  g3_multiprocessing.handle_test_main(googletest.main)

tests/sft/checkpoint_manager_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,34 @@ def setUp(self):
101101
axis_names=('fsdp', 'tp'),
102102
)
103103

104+
def test_handlers_options(self):
  """Checks handler OCDBT/Zarr3 flags against the configured JAX platforms.

  When `jax.config.jax_platforms` mentions 'proxy' or 'pathways', both the
  'model_params' and 'optimizer_state' handlers are expected to have OCDBT
  disabled; otherwise OCDBT is expected to be enabled. Zarr3 is expected to
  be disabled in either case.
  """
  manager = checkpoint_manager.CheckpointManager(
      f'{self.temp_path}/{self.id()}'
  )

  configured_platforms = jax.config.jax_platforms or ''
  on_pathways = any(
      tag in configured_platforms for tag in ('proxy', 'pathways')
  )

  # Reaches through private attributes to find the per-item handlers.
  composite = manager._checkpoint_manager._checkpointer._handler  # pytype: disable=attribute-error
  entries = composite._handler_registry.get_all_entries()  # pytype: disable=attribute-error

  # Registry keys are pairs; only the item name (first element) is needed.
  handlers = {item_name: h for (item_name, _), h in entries.items()}

  item_names = ('model_params', 'optimizer_state')
  for name in item_names:
    self.assertIn(name, handlers)

  check_ocdbt = self.assertFalse if on_pathways else self.assertTrue
  for name in item_names:
    check_ocdbt(handlers[name]._use_ocdbt)  # pytype: disable=attribute-error

  for name in item_names:
    self.assertFalse(handlers[name]._use_zarr3)  # pytype: disable=attribute-error
131+
104132
def test_empty_root_directory(self):
105133
cp_manager = checkpoint_manager.CheckpointManager(root_directory=None)
106134
self.assertIsNone(cp_manager.latest_step())
@@ -299,6 +327,12 @@ def test_restore_with_backward_compatibility(self, ckpt_path):
299327
# The checkpoints in test_data is saved with StandardSave. The test is to
300328
# verify the checkpoint manager with PyTreeRestore can still restore the
301329
# checkpoints saved with StandardSave.
330+
if os.getenv('ENABLE_PATHWAYS_PERSISTENCE', '') == '1':
331+
self.skipTest(
332+
'Pathways persistence cannot read standard backwards-compatible'
333+
' checkpoints.'
334+
)
335+
302336
ckpt_manager = checkpoint_manager.CheckpointManager(
303337
os.path.join(os.path.dirname(__file__), ckpt_path)
304338
)

tunix/sft/checkpoint_manager.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ def __init__(
5050
if root_directory is not None:
5151
# When using Pathways, the checkpoint manager only supports persistence
5252
# APIs now.
53-
if 'proxy' in os.getenv('JAX_PLATFORMS', ''):
53+
platforms = jax.config.jax_platforms or ''
54+
if (
55+
'proxy' in platforms
56+
or 'pathways' in platforms
57+
):
5458
item_handlers = {
5559
'model_params': ocp.PyTreeCheckpointHandler(
5660
use_ocdbt=False,
@@ -65,6 +69,11 @@ def __init__(
6569
logging.info(
6670
'Using persistence API for checkpointing with Pathways.'
6771
)
72+
ocp.pathways.register_type_handlers(
73+
checkpointing_impl=ocp.pathways.CheckpointingImpl.from_options(
74+
use_remote_python=True,
75+
)
76+
)
6877
else:
6978
logging.warning(
7079
'Checkpointing without the persistence API, be aware of potential'

0 commit comments

Comments
 (0)