Skip to content

Commit b19ddc2

Browse files
committed
refactor pytree args
1 parent b6689b4 commit b19ddc2

File tree

3 files changed

+95
-2
lines changed

3 files changed

+95
-2
lines changed

checkpoint/orbax/checkpoint/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from orbax.checkpoint.checkpoint_manager import CheckpointManager
5555
from orbax.checkpoint.checkpoint_manager import AsyncOptions
5656
from orbax.checkpoint.checkpoint_manager import CheckpointManagerOptions
57+
from orbax.checkpoint.checkpoint_manager import PyTreeOptions
5758

5859
from orbax.checkpoint._src.handlers.pytree_checkpoint_handler import RestoreArgs
5960
from orbax.checkpoint._src.handlers.pytree_checkpoint_handler import ArrayRestoreArgs

checkpoint/orbax/checkpoint/checkpoint_manager.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import threading
2323
import time
2424
import typing
25+
import numpy as np
2526
from typing import Any, Callable, Container, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union, overload
2627

2728
from absl import logging
@@ -64,6 +65,12 @@
6465
from orbax.checkpoint._src.path import utils as path_utils
6566
from typing_extensions import Self # for Python version < 3.11
6667

68+
from orbax.checkpoint import type_handlers
69+
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
70+
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
71+
from orbax.checkpoint._src.serialization.type_handlers import PLACEHOLDER
72+
from orbax.checkpoint._src.serialization.type_handlers import PlaceholderHandler
73+
6774

6875

6976
PyTree = Any
@@ -96,6 +103,7 @@
96103
AsyncOptions = options_lib.AsyncOptions
97104
MultiprocessingOptions = options_lib.MultiprocessingOptions
98105
FileOptions = options_lib.FileOptions
106+
PyTreeOptions = options_lib.PyTreeOptions
99107

100108
DEFAULT_ITEM_NAME = 'default'
101109
METRIC_ITEM_NAME = 'metrics'
@@ -367,6 +375,10 @@ class CheckpointManagerOptions:
367375
supposed to be created per process. This is used to support async
368376
directory creation. If True, `multiprocessing_options.primary_host` must be
369377
None.
378+
pytree_options: PyTreeOptions instance to configure PyTree checkpointing
379+
behavior, including array handler options like use_replica_parallel and
380+
enable_pinned_host_transfer. If not provided, default values will be used.
381+
See PyTreeOptions for more details.
370382
"""
371383

372384
save_interval_steps: int = 1
@@ -391,6 +403,7 @@ class CheckpointManagerOptions:
391403
read_only: bool = False
392404
enable_async_checkpointing: bool = True
393405
async_options: Optional[AsyncOptions] = None
406+
pytree_options: Optional[PyTreeOptions] = None
394407
multiprocessing_options: MultiprocessingOptions = dataclasses.field(
395408
default_factory=MultiprocessingOptions
396409
)
@@ -712,6 +725,7 @@ def __init__(
712725

713726
self._options = options or CheckpointManagerOptions()
714727
self._multiprocessing_options = self._options.multiprocessing_options
728+
self._pytree_options = self._options.pytree_options or PyTreeOptions()
715729

716730
if self._options.enable_per_process_directory_creation:
717731
future.AwaitableSignalsContract.awaitable_signals_contract_prefix += (
@@ -1065,12 +1079,23 @@ def _configure_checkpointer_from_item_names_and_handlers(
10651079
for item_name, handler in item_handlers.items():
10661080
all_item_handlers[item_name] = handler
10671081

1068-
for item_name in all_item_handlers:
1082+
pytree_options = self._pytree_options
1083+
set_handlers = {}
1084+
for item_name, handler in all_item_handlers.items():
10691085
if item_name in RESERVED_ITEM_NAMES:
10701086
raise ValueError(
10711087
f'Found {item_name} in `checkpointers`; this is a reserved key.'
10721088
)
1073-
all_item_handlers[METRIC_ITEM_NAME] = self._metrics_handler
1089+
if handler is None:
1090+
set_handlers[item_name] = self._create_default_pytree_handler(
1091+
options, pytree_options
1092+
)
1093+
else:
1094+
set_handlers[item_name] = handler
1095+
1096+
set_handlers[METRIC_ITEM_NAME] = self._metrics_handler
1097+
all_item_handlers = set_handlers
1098+
10741099
# CompositeCheckpointHandler defers per-item handler creation until
10751100
# save/restore time.
10761101
async_options = options.async_options or AsyncOptions()
@@ -1090,6 +1115,44 @@ def _configure_checkpointer_from_item_names_and_handlers(
10901115
options.enable_async_checkpointing,
10911116
)
10921117

1118+
def _create_default_pytree_handler(
1119+
self,
1120+
options: CheckpointManagerOptions,
1121+
pytree_options: PyTreeOptions
1122+
) -> pytree_checkpoint_handler.PyTreeCheckpointHandler:
1123+
"""Creates a default pytree handler."""
1124+
custom_array_handler = type_handlers.ArrayHandler(
1125+
primary_host=options.multiprocessing_options.primary_host,
1126+
use_replica_parallel=pytree_options.use_replica_parallel,
1127+
min_slice_bytes_for_replica_parallel=pytree_options.min_slice_bytes_for_replica_parallel,
1128+
max_replicas_for_replica_parallel=pytree_options.max_replicas_for_replica_parallel,
1129+
enable_replica_parallel_separate_folder=pytree_options.enable_replica_parallel_separate_folder,
1130+
array_metadata_store=array_metadata_store_lib.Store(),
1131+
)
1132+
1133+
custom_registry = type_handlers.create_type_handler_registry(
1134+
(int, type_handlers.ScalarHandler()),
1135+
(float, type_handlers.ScalarHandler()),
1136+
(bytes, type_handlers.ScalarHandler()),
1137+
(np.number, type_handlers.ScalarHandler()),
1138+
(np.ndarray, type_handlers.NumpyHandler()),
1139+
(jax.Array, custom_array_handler),
1140+
(str, type_handlers.StringHandler()),
1141+
(type(PLACEHOLDER), PlaceholderHandler()),
1142+
)
1143+
1144+
return pytree_checkpoint_handler.PyTreeCheckpointHandler(
1145+
save_concurrent_gb=pytree_options.save_concurrent_gb,
1146+
restore_concurrent_gb=pytree_options.restore_concurrent_gb,
1147+
save_device_host_concurrent_gb=pytree_options.save_device_host_concurrent_gb,
1148+
use_ocdbt=pytree_options.use_ocdbt,
1149+
use_zarr3=pytree_options.use_zarr3,
1150+
use_compression=pytree_options.use_compression,
1151+
multiprocessing_options=options.multiprocessing_options,
1152+
type_handler_registry=custom_registry,
1153+
enable_pinned_host_transfer=pytree_options.enable_pinned_host_transfer,
1154+
)
1155+
10931156
def _configure_checkpointer_from_handler_registry(
10941157
self,
10951158
handler_registry: CheckpointHandlerRegistry,

checkpoint/orbax/checkpoint/options.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,32 @@ class FileOptions:
7070
"""
7171

7272
path_permission_mode: int | None = None
73+
74+
@dataclasses.dataclass
75+
class PyTreeOptions:
76+
"""Options for PyTree checkpointing behavior
77+
78+
Attributes:
79+
enable_pinned_host_transfer: Whether to use pinned host memory for D2H transfer
80+
use_replica_parallel: Whether to parallelize saving across replicas
81+
min_slice_bytes_for_replica_parallel: Minimum bytes per replica slice
82+
max_replicas_for_replica_parallel: Maximum replicas for parallel saving
83+
enable_replica_parallel_separate_folder: Save replicated/sharded in separate folders
84+
save_concurrent_gb: Max concurrent GB for writing
85+
restore_concurrent_gb: Max concurrent GB for reading
86+
save_device_host_concurrent_gb: Max concurrent GB for D2H transfer
87+
use_ocdbt: Use Tensorstore OCDBT driver
88+
use_zarr3: Use Zarr version 3
89+
use_compression: Use compression (zstd for zarr2)
90+
"""
91+
enable_pinned_host_transfer: Optional[bool] = None
92+
use_replica_parallel: bool = True
93+
min_slice_bytes_for_replica_parallel: Optional[int] = None
94+
max_replicas_for_replica_parallel: Optional[int] = None
95+
enable_replica_parallel_separate_folder: bool = False
96+
save_concurrent_gb: Optional[int] = None
97+
restore_concurrent_gb: Optional[int] = None
98+
save_device_host_concurrent_gb: Optional[int] = None
99+
use_ocdbt: bool = True
100+
use_zarr3: bool = False
101+
use_compression: bool = True

0 commit comments

Comments
 (0)