Skip to content

Commit 5d84c76

Browse files
cpgaffney1Orbax Authors
authored and
Orbax Authors
committed
Internal change.
PiperOrigin-RevId: 733059068
1 parent 4a24304 commit 5d84c76

20 files changed

+899
-352
lines changed

.github/workflows/build.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
- name: Test with pytest
6161
# TODO(yaning): Move these to an exclude target within pytest.ini.
6262
run: |
63-
python -m pytest --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py
63+
python -m pytest --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py
6464
# The below step just reports the success or failure of tests as a "commit status".
6565
# This is needed for copybara integration.
6666
- name: Report success or failure as github status

checkpoint/CHANGELOG.md

+7-1
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Fixed
11+
12+
- Fix RESOURCE_EXHAUSTED while writing array_metadatas.
13+
1014
### Changed
1115

1216
- Improve `Cannot serialize host local jax.Array` error message.
1317

1418
### Added
1519

16-
- support saving and restoring jax.random.key() in PyTree
20+
- support saving and restoring jax.random.key() in PyTree.
21+
- `CheckpointableHandler` for V1.
22+
- Support single-slice checkpointing in `emergency.CheckpointManager`.
1723

1824
## [0.11.6] - 2025-02-20
1925

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

+2
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,8 @@ def __init__(
323323
type_handler_registry
324324
)
325325
)
326+
if self._array_metadata_store:
327+
self._array_metadata_store.set_primary_host(self._primary_host)
326328
self._array_metadata_validator = array_metadata_validator
327329

328330

checkpoint/orbax/checkpoint/_src/metadata/array_metadata_store.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,45 @@ def __init__(
133133
self,
134134
path_resolver: PathResolver = PathResolver(),
135135
ser_deser: SerDeserializer = SerDeserializer(),
136+
primary_host: int | None = 0, # None means all hosts are primary hosts.
137+
write_timeout_secs: int = 600, # 10 minutes.
136138
):
137139
self._path_resolver = path_resolver
138140
self._ser_deser = ser_deser
141+
self._primary_host = primary_host
142+
self._write_timeout_secs = write_timeout_secs
143+
144+
def set_primary_host(self, primary_host: int | None) -> None:
145+
"""Sets the primary host."""
146+
self._primary_host = primary_host
147+
148+
async def _maybe_create_base_dir(self, base_dir: epath.Path) -> None:
149+
"""Primary host creates the base directory, rest of the hosts wait."""
150+
if multihost.is_primary_host(self._primary_host):
151+
# primary host creates, rest of the hosts wait.
152+
return await asyncio.to_thread(
153+
base_dir.mkdir, parents=True, exist_ok=True
154+
)
155+
156+
# non-primary host waits for primary host to create the base dir/folder.
157+
async def wait_for_base_dir_creation():
158+
while not await asyncio.to_thread(base_dir.exists):
159+
await asyncio.sleep(0.25)
160+
161+
try:
162+
await asyncio.wait_for(
163+
wait_for_base_dir_creation(), timeout=self._write_timeout_secs
164+
)
165+
except asyncio.TimeoutError as e:
166+
primary_process = (
167+
'LOCAL' if self._primary_host is None else self._primary_host
168+
)
169+
raise ValueError(
170+
f'[process_index={multihost.process_index()}] Timed out waiting for'
171+
f' array_metadatas base directory creation: {base_dir}.'
172+
f' timeout={self._write_timeout_secs} seconds.'
173+
f' primary_process={primary_process}'
174+
) from e
139175

140176
async def write(
141177
self,
@@ -155,7 +191,7 @@ async def write(
155191
file_path = self._path_resolver.get_write_file_path(
156192
checkpoint_dir, process_index
157193
)
158-
await asyncio.to_thread(file_path.parent.mkdir, parents=True, exist_ok=True)
194+
await self._maybe_create_base_dir(file_path.parent)
159195
await asyncio.to_thread(
160196
file_path.write_text, self._ser_deser.serialize(array_metadatas)
161197
)

checkpoint/orbax/checkpoint/_src/multihost/multislice.py

+2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ def slice_count(
7676
global_mesh: jax.sharding.Mesh, *, replica_axis_index: int = 0
7777
) -> int:
7878
"""Number of slices implied by the mesh's replica dimension."""
79+
if len(global_mesh.shape_tuple) == 1:
80+
return 1
7981
return global_mesh.devices.shape[replica_axis_index]
8082

8183

checkpoint/orbax/checkpoint/abstract_checkpoint_manager.py

+12-169
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@ def best_step(self) -> Optional[int]:
8585

8686
@abc.abstractmethod
8787
def reload(self):
88-
"""Performs disk reads to ensure internal properties are up to date."""
88+
"""Reloads internal properties.
89+
90+
Resets internal cache of checkpoint steps, in case the directory managed
91+
by this object has been updated externally.
92+
"""
8993

9094
@abc.abstractmethod
9195
def reached_preemption(self, step: int) -> bool:
@@ -112,186 +116,25 @@ def delete(self, step: int):
112116
def save(
113117
self,
114118
step: int,
115-
items: Optional[Union[Any, Mapping[str, Any]]] = None,
116-
save_kwargs: Optional[Union[SaveParams, Mapping[str, SaveParams]]] = None,
117-
metrics: Optional[PyTree] = None,
118-
force: Optional[bool] = False,
119-
args: Optional[args_lib.CheckpointArgs] = None,
120-
custom_metadata: dict[str, Any] | None = None,
119+
*args,
120+
**kwargs,
121121
) -> bool:
122-
"""Saves the provided items.
123-
124-
This method should be called by all hosts - process synchronization and
125-
actions that need to be performed on only one host are managed internally.
126-
127-
NOTE: The `items` and `save_kwargs` arguments are deprecated, use `args`
128-
instead. Make sure to configure `CheckpointManager` with `item_names`.
129-
130-
`args` should be a subclass of
131-
`orbax.checkpoint.args.CheckpointArgs`, the specific type of which is used
132-
to indicate what logic is used to save the object. For a typical, PyTree of
133-
arrays, use `StandardSave`/`StandardRestore`.
134-
135-
When constructing the `CheckpointManager`, if no `item_names` were provided,
136-
it is assumed that we are managing a single object. If `item_names` were
137-
provided, it is assumed that we are managing multiple objects, and `args`
138-
must be `orbax.checkpoint.args.CompositeArgs`. See below for details.
139-
140-
Example::
141-
142-
# Single item
143-
mngr = ocp.CheckpointManager(directory)
144-
mngr.save(step, args=ocp.args.StandardSave(my_train_state))
145-
146-
# Multiple items
147-
mngr = ocp.CheckpointManager(directory, item_names=('state', 'meta'))
148-
mngr.save(step, args=ocp.args.Composite(
149-
state=ocp.args.StandardSave(my_train_state),
150-
meta=ocp.args.JsonSave(my_metadata)
151-
))
152-
153-
Args:
154-
step: current step, int
155-
items: a savable object, or a dictionary of object name to savable object.
156-
save_kwargs: save kwargs for a single Checkpointer, or a dictionary of
157-
object name to kwargs needed by the Checkpointer implementation to save
158-
the object.
159-
metrics: a dictionary of metric name (string) to numeric value to be
160-
tracked along with this checkpoint. Required if `options.best_fn` is
161-
set. Allows users to specify a metric value to determine which
162-
checkpoints are best and should be kept (in conjunction with
163-
`options.max_to_keep`).
164-
force: if `True`, this method will attempt to save a checkpoint regardless
165-
of the result of `AbstractCheckpointManager.should_save(step)`. By
166-
default, `save` will only write a checkpoint to disk when the options
167-
permit, e.g. when `step` is in `options.save_interval_steps` or
168-
`options.save_on_steps`. Setting `force=True` will not overwrite
169-
existing checkpoints.
170-
args: `CheckpointArgs` which is used to save checkpointable objects with
171-
the appropriate logic.
172-
custom_metadata: a dictionary of custom metadata to be written to the
173-
checkpoint directory via StepMetadata.
174-
175-
Returns:
176-
bool indicating whether a save operation was performed.
177-
Raises:
178-
ValueError: if `track_best` was indicated but `metrics` is not provided.
179-
ValueError: directory creation failed.
180-
ValueError: if an item is provided for which no `Checkpointer` is
181-
found.
182-
ValueError: if the checkpoint already exists.
183-
"""
122+
"""Saves the given step."""
184123

185124
@abc.abstractmethod
186125
def restore(
187126
self,
188127
step: Optional[int],
189-
items: Optional[Union[Any, Mapping[str, Any]]] = None,
190-
restore_kwargs: Optional[
191-
Union[RestoreParams, Mapping[str, RestoreParams]]
192-
] = None,
193-
directory: Optional[epath.PathLike] = None,
194-
args: Optional[args_lib.CheckpointArgs] = None,
128+
*args,
129+
**kwargs,
195130
) -> Union[Any, Mapping[str, Any], args_lib.Composite]:
196-
"""Restores from the given step and provided items.
197-
198-
This method should be called by all hosts - process synchronization and
199-
actions that need to be performed on only one host are managed internally.
200-
201-
NOTE: The `items` and `restore_kwargs` arguments are deprecated, use `args`
202-
instead. Make sure to configure `CheckpointManager` with `item_names`.
203-
See `save` docstring for additional details.
204-
205-
Example::
206-
207-
# Single item
208-
mngr = ocp.CheckpointManager(directory)
209-
mngr.restore(step, args=ocp.args.StandardRestore(abstract_train_state))
210-
211-
# Multiple items
212-
mngr = ocp.CheckpointManager(directory, item_names=('state', 'meta'))
213-
mngr.restore(step, args=ocp.args.Composite(
214-
state=ocp.args.StandardRestore(abstract_train_state),
215-
meta=ocp.args.JsonRestore(),
216-
))
217-
# If it is acceptable to restore without providing additional arguments,
218-
# and if a save has already been performed, it is ok to do the following:
219-
mngr.restore(step, args=ocp.args.Composite(state=None, meta=None))
220-
# If a save has not already been performed, there is no way for Orbax to
221-
# know how to restore the objects. If a save has already been performed,
222-
# it remembers the logic used to save the objects.
223-
224-
Args:
225-
step: current step, int
226-
items: a restoreable object, or a dictionary of object name to restorable
227-
object.
228-
restore_kwargs: restore kwargs for a single Checkpointer, or a dictionary
229-
of object name to kwargs needed by the Checkpointer implementation to
230-
restore the object.
231-
directory: if provided, uses the given directory rather than the
232-
`directory` property of this class. Can be used to restore checkpoints
233-
from an independent location.
234-
args: `CheckpointArgs` which is used to restore checkpointable objects
235-
with the appropriate logic.
236-
237-
Returns:
238-
If managing a single item, returns a single checkpointable object.
239-
If managing multiple items, returns ocp.args.Composite, where the keys
240-
are item names, and values are checkpointable objects.
241-
"""
131+
"""Restores the given step."""
242132

243133
@abc.abstractmethod
244134
def item_metadata(
245135
self, step: int
246136
) -> Union[Any, Mapping[str, Any], args_lib.Composite]:
247-
"""For all Checkpointers, returns any metadata associated with the item.
248-
249-
Calls the `metadata` method for each Checkpointer and returns a
250-
mapping of each item name to the restored metadata. If the manager only
251-
manages a single item, a single metadata will be returned instead.
252-
253-
To avoid errors due to missing CheckpointHandlers, concrete
254-
CheckpointManager constructor must allow mapping from item names to
255-
respective CheckpointHandlers to be input other than via save() and
256-
restore(). Please note that save() and restore() calls automatically
257-
map CheckpointHandlers to respective item names and retain it during the
258-
lifetime of the CheckpointManager instance.
259-
260-
Example::
261-
262-
# Single item
263-
mngr = ocp.CheckpointManager(directory)
264-
# No calls to save() or restore() before calling item_metadata().
265-
mngr.item_metadata(step) # Raises error.
266-
267-
mngr = ocp.CheckpointManager(directory,
268-
item_handlers=ocp.StandardCheckpointHandler)
269-
# No calls to save() or restore() before calling item_metadata().
270-
metadata = mngr.item_metadata(step) # Successful.
271-
272-
# Multiple items
273-
mngr = ocp.CheckpointManager(directory, item_names=('state', 'extra'))
274-
# No calls to save() or restore() before calling item_metadata().
275-
mngr.item_metadata(step) # Raises error.
276-
277-
mngr = ocp.CheckpointManager(directory,
278-
item_names=('state', 'extra'),
279-
item_handlers={
280-
'state': ocp.StandardCheckpointHandler,
281-
'extra': ocp.PytreeCheckpointHandler,
282-
}
283-
)
284-
# No calls to save() or restore() before calling item_metadata().
285-
metadata = mngr.item_metadata(step) # Successful.
286-
287-
Metadata may be None for an individual item.
288-
289-
Args:
290-
step: Step for which to retrieve metadata.
291-
292-
Returns:
293-
A dictionary mapping name to item metadata, or a single item metadata.
294-
"""
137+
"""Returns metadata for all known items."""
295138

296139
@abc.abstractmethod
297140
def metadata(

0 commit comments

Comments
 (0)