Skip to content

Commit 5539ee5

Browse files
Author: Orbax Authors (committed)
#p2p Add Grain data iterator checkpointing to P2P
PiperOrigin-RevId: 863828755
1 parent 230ecb1 commit 5539ee5

File tree

8 files changed: 227 additions and 9 deletions (+227 / -9 lines changed).

checkpoint/orbax/checkpoint/experimental/emergency/p2p/args.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,41 @@
1414

1515
"""P2P composite checkpoint argument."""
1616

17-
from typing import final
17+
from typing import Any, final
1818
from orbax.checkpoint import args as args_lib
1919
from orbax.checkpoint.experimental.emergency.p2p import constants
20+
from orbax.checkpoint.experimental.emergency.p2p import utils
21+
22+
23+
def _check_data_iter(value: Any):
24+
"""Checks if data_iter is valid."""
25+
if utils.pygrain() is None:
26+
raise ImportError(
27+
'grain library is not available. Please install grain to use data_iter.'
28+
)
29+
if not isinstance(
30+
value,
31+
(
32+
utils.pygrain().PyGrainCheckpointSave,
33+
utils.pygrain().PyGrainCheckpointRestore,
34+
),
35+
):
36+
raise TypeError(f'Unsupported type for data_iter: {type(value)}')
2037

2138

2239
@final
2340
class Composite(args_lib.Composite):
24-
"""Composite argument that only supports 'state' key."""
41+
"""Composite argument that supports 'state' and 'data_iter' keys."""
2542

2643
def __init__(self, *args, **kwargs):
2744
super().__init__(*args, **kwargs)
28-
if constants.STATE_SUBDIR not in self or len(self) > 1:
45+
if constants.STATE_SUBDIR not in self:
2946
raise ValueError(
30-
f'Composite must contain "{constants.STATE_SUBDIR}" key and no other'
31-
f' keys: {list(self.keys())}'
47+
f'Composite must contain "{constants.STATE_SUBDIR}" key:'
48+
f' {list(self.keys())}'
3249
)
50+
for key in self:
51+
if key not in [constants.STATE_SUBDIR, constants.DATA_ITER_KEY]:
52+
raise ValueError(f'Unsupported key in Composite: {key}')
53+
if key == constants.DATA_ITER_KEY:
54+
_check_data_iter(self[key])

checkpoint/orbax/checkpoint/experimental/emergency/p2p/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
# Directory names
1818
P2P_RESTORE_DIR_NAME = 'p2p_restore'
1919
STATE_SUBDIR = 'state'
20+
DATA_ITER_KEY = 'data_iter'
2021
PROCESS_SUBDIR_PREFIX = 'ocdbt.process_'
22+
PYGRAIN_STATES_FILENAME = 'pygrain_states.json'
2123

2224
# Tuning for high-throughput networks (16MB buffers)
2325
SOCKET_BUFFER_SIZE = 16 * 1024 * 1024

checkpoint/orbax/checkpoint/experimental/emergency/p2p/local.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,100 @@
1414

1515
"""Internal checkpoint manager for local P2P storage logic."""
1616

17+
import dataclasses
18+
import json
1719
from typing import Any, Sequence, final
1820

1921
from absl import logging
2022
from etils import epath
2123
import jax
2224
import orbax.checkpoint as ocp
25+
from orbax.checkpoint import args as args_lib
2326
from orbax.checkpoint import checkpoint_manager
2427
from orbax.checkpoint import type_handlers
2528
from orbax.checkpoint._src.multihost import multihost
2629
from orbax.checkpoint._src.serialization import type_handler_registry
2730
from orbax.checkpoint.experimental.emergency import checkpoint_manager as emergency_checkpoint_manager
31+
from orbax.checkpoint.experimental.emergency import path as emergency_path
2832
from orbax.checkpoint.experimental.emergency.p2p import args as p2p_args_lib
33+
from orbax.checkpoint.experimental.emergency.p2p import constants
2934
from orbax.checkpoint.experimental.emergency.p2p import utils
3035

3136

37+
if utils.pygrain() is not None:
38+
39+
class _LocalPyGrainHandlerMixin(utils.pygrain().PyGrainCheckpointHandler):
40+
"""Mixin for Local PyGrain handler."""
41+
42+
def __init__(self, process_index: int):
43+
self._process_index = process_index
44+
45+
def save(self, directory: epath.Path, item: Any = None, args: Any = None):
46+
"""Saves the PyGrain iterator state to a JSON file."""
47+
item = item or args.item
48+
state = item.get_state()
49+
50+
if isinstance(item, utils.pygrain().DatasetIterator):
51+
state_val = state
52+
else:
53+
# DataLoaderIterator state is bytes, decode to string for JSON
54+
state_val = state.decode()
55+
56+
local_data = {
57+
str(self._process_index): state_val,
58+
}
59+
60+
all_data = emergency_path.sync_global_data(local_data)
61+
62+
combined_data = {}
63+
for entry in all_data:
64+
combined_data.update(entry)
65+
66+
(directory / constants.PYGRAIN_STATES_FILENAME).write_text(
67+
json.dumps(combined_data, indent=2)
68+
)
69+
70+
def restore(
71+
self, directory: epath.Path, item: Any = None, args: Any = None
72+
):
73+
"""Restores the PyGrain iterator state from a JSON file."""
74+
item = item or args.item
75+
path = directory / constants.PYGRAIN_STATES_FILENAME
76+
77+
if not path.exists():
78+
raise ValueError(f'PyGrain states not found at {path}')
79+
80+
combined_data = json.loads(path.read_text())
81+
my_key = str(self._process_index)
82+
83+
if my_key not in combined_data:
84+
raise ValueError(
85+
f'Process index {self._process_index} not found in {path}'
86+
)
87+
88+
state_val = combined_data[my_key]
89+
90+
if isinstance(item, utils.pygrain().DatasetIterator):
91+
# DatasetIterator expects a dict
92+
state = state_val
93+
else:
94+
# DataLoaderIterator expects bytes
95+
state = state_val.encode()
96+
97+
item.set_state(state)
98+
return item
99+
100+
@ocp.args.register_with_handler(_LocalPyGrainHandlerMixin, for_save=True)
101+
@dataclasses.dataclass
102+
class LocalPyGrainSave(ocp.args.CheckpointArgs):
103+
item: Any
104+
105+
@ocp.args.register_with_handler(_LocalPyGrainHandlerMixin, for_restore=True)
106+
@dataclasses.dataclass
107+
class LocalPyGrainRestore(utils.pygrain().PyGrainCheckpointRestore):
108+
item: Any
109+
110+
32111
@final
33112
class LocalCheckpointManager:
34113
"""Wrapper around Orbax CheckpointManager for local P2P shards."""
@@ -77,10 +156,16 @@ def __init__(
77156
type_handler_registry=local_registry,
78157
)
79158

159+
item_handlers = dict(state=handler)
160+
if utils.pygrain() is not None:
161+
item_handlers[constants.DATA_ITER_KEY] = _LocalPyGrainHandlerMixin(
162+
self._process_index
163+
)
164+
80165
self._manager = checkpoint_manager.CheckpointManager(
81166
self._directory,
82167
options=p2p_specific_options,
83-
item_handlers=dict(state=handler),
168+
item_handlers=item_handlers,
84169
)
85170

86171
@property
@@ -122,6 +207,14 @@ def save(
122207
force: bool = False,
123208
) -> bool:
124209
"""Saves the checkpoint."""
210+
if utils.pygrain() is not None and constants.DATA_ITER_KEY in args:
211+
original_save = args[constants.DATA_ITER_KEY]
212+
args_dict = dict(args.items())
213+
args_dict[constants.DATA_ITER_KEY] = LocalPyGrainSave(
214+
item=original_save.item
215+
)
216+
args = args_lib.Composite(**args_dict)
217+
125218
return self._manager.save(step, args=args, force=force)
126219

127220
def restore(
@@ -146,6 +239,14 @@ def restore(
146239
)
147240
raise ValueError(error_msg)
148241

242+
if utils.pygrain() is not None and args and constants.DATA_ITER_KEY in args:
243+
original_restore = args[constants.DATA_ITER_KEY]
244+
args_dict = dict(args.items())
245+
args_dict[constants.DATA_ITER_KEY] = LocalPyGrainRestore(
246+
original_restore.item
247+
)
248+
args = args_lib.Composite(**args_dict)
249+
149250
# 2. Delegate to Orbax
150251
restored = self._manager.restore(
151252
step,

checkpoint/orbax/checkpoint/experimental/emergency/p2p/local_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,9 @@ def test_save_restore(self, unused_process_index):
113113
),
114114
)
115115

116-
jax.tree_util.tree_map(np.testing.assert_array_equal, state, restored.state)
116+
jax.tree_util.tree_map(
117+
np.testing.assert_array_equal, state, restored['state']
118+
)
117119
manager.close()
118120

119121

checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from orbax.checkpoint._src.serialization import type_handlers
3030
from orbax.checkpoint.experimental.emergency import checkpoint_manager as emergency_checkpoint_manager
3131
from orbax.checkpoint.experimental.emergency.p2p import args as p2p_args_lib
32+
from orbax.checkpoint.experimental.emergency.p2p import constants
33+
from orbax.checkpoint.experimental.emergency.p2p import utils
3234

3335
_PRIMARY_REPLICA_ID = 0
3436
PyTree = Any
@@ -115,10 +117,14 @@ def __init__(
115117
enable_async_checkpointing=True,
116118
)
117119

120+
item_handlers = dict(state=_create_persistent_handler(mp_options))
121+
if utils.pygrain() is not None:
122+
item_handlers['data_iter'] = utils.pygrain().PyGrainCheckpointHandler()
123+
118124
self._manager = checkpoint_manager.CheckpointManager(
119125
self._directory,
120126
options=internal_options,
121-
item_handlers=dict(state=_create_persistent_handler(mp_options)),
127+
item_handlers=item_handlers,
122128
)
123129

124130
@property
@@ -169,8 +175,11 @@ def restore(
169175
abstract_state, sharding_tree
170176
),
171177
)
178+
restore_kwargs = {'state': restore_args_obj}
179+
if constants.DATA_ITER_KEY in args:
180+
restore_kwargs[constants.DATA_ITER_KEY] = args.data_iter
172181
return self._manager.restore(
173-
step, args=p2p_args_lib.Composite(state=restore_args_obj)
182+
step, args=p2p_args_lib.Composite(**restore_kwargs)
174183
)
175184

176185
def delete(self, step: int):

checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent_test.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from absl.testing import absltest
1818
from etils import epath
19+
import grain.python as pygrain
1920
import jax
2021
import numpy as np
2122
from orbax.checkpoint import args as args_lib
@@ -220,6 +221,69 @@ def _to_abstract(x):
220221
test_utils.assert_tree_equal(self, state, restored_state)
221222
manager.close()
222223

224+
@mock.patch(
225+
'orbax.checkpoint._src.multihost.multihost.process_index', return_value=0
226+
)
227+
def test_save_restore_with_grain_iterator(self, unused_process_index):
228+
self._patch_process_index(process_index=0)
229+
# persistent checkpoint manager with multiprocessing only works with a
230+
# unified storage.
231+
self.enter_context(mock.patch.object(jax, 'process_count', return_value=1))
232+
devices = np.array([
233+
[MockDevice(0, 0)],
234+
])
235+
mesh = mock.Mock(
236+
spec=jax.sharding.Mesh,
237+
devices=devices,
238+
axis_names=('replica', 'data'),
239+
shape={'replica': 1, 'data': 1},
240+
shape_tuple=devices.shape,
241+
size=devices.size,
242+
)
243+
244+
manager = persistent.PersistentCheckpointManager(
245+
self.directory, mesh, replica_axis_index=0, options=self.options
246+
)
247+
248+
ds = pygrain.MapDataset.source(list(range(10)))
249+
dl = pygrain.DataLoader(
250+
data_source=ds,
251+
sampler=pygrain.SequentialSampler(10, pygrain.ShardOptions(0, 1)),
252+
operations=[pygrain.Batch(1)],
253+
)
254+
data_iter = iter(dl)
255+
for _ in range(3):
256+
next(data_iter)
257+
258+
arr = jax.device_put(np.arange(self.mesh.size, dtype=np.int32))
259+
state = {'a': arr}
260+
save_args = p2p_args_lib.Composite(
261+
state=args_lib.PyTreeSave(state),
262+
data_iter=pygrain.PyGrainCheckpointSave(data_iter),
263+
)
264+
manager.save(1, args=save_args)
265+
manager.wait_until_finished()
266+
267+
new_dl = pygrain.DataLoader(
268+
data_source=ds,
269+
sampler=pygrain.SequentialSampler(10, pygrain.ShardOptions(0, 1)),
270+
operations=[pygrain.Batch(1)],
271+
)
272+
new_data_iter = iter(new_dl)
273+
# PersistentCheckpointManager expects the state with sharding information
274+
# in args.state.
275+
restore_args = p2p_args_lib.Composite(
276+
state=state,
277+
data_iter=pygrain.PyGrainCheckpointRestore(new_data_iter),
278+
)
279+
restored = manager.restore(1, args=restore_args)
280+
281+
self.assertIn('state', restored)
282+
self.assertIn('data_iter', restored)
283+
test_utils.assert_tree_equal(self, state, restored['state'])
284+
self.assertEqual(next(restored['data_iter']), 3)
285+
manager.close()
286+
223287
def test_delete_in_primary_slice_deletes(self):
224288
self._patch_process_index(process_index=0)
225289
manager = persistent.PersistentCheckpointManager(

checkpoint/orbax/checkpoint/experimental/emergency/p2p/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,27 @@
1414

1515
"""Utils for P2P checkpointing."""
1616

17+
from typing import Any
18+
1719
from absl import logging
1820
from etils import epath
21+
1922
from orbax.checkpoint.experimental.emergency.p2p import constants
2023

24+
# pytype:disable=import-error
25+
# pylint:disable=g-import-not-at-top
26+
try:
27+
import grain.python as pygrain_module
28+
except ImportError:
29+
pygrain_module = None
30+
# pytype:enable=import-error
31+
# pylint:enable=g-import-not-at-top
32+
33+
34+
def pygrain() -> Any | None:
35+
"""Returns the grain.python module if available, otherwise None."""
36+
return pygrain_module
37+
2138

2239
def detect_process_index(directory: epath.Path, step: int) -> int | None:
2340
"""Inspects the disk to find which process index created this step."""

checkpoint/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ testing = [
7272
'tensorflow',
7373
'fastapi',
7474
'httpx',
75+
'grain',
7576
]
7677

7778
[tool.flit.sdist]

0 commit comments

Comments
 (0)