2222import threading
2323import time
2424import typing
25+ import numpy as np
2526from typing import Any , Callable , Container , Iterable , List , Mapping , Optional , Sequence , Tuple , Type , Union , overload
2627
2728from absl import logging
6465from orbax .checkpoint ._src .path import utils as path_utils
6566from typing_extensions import Self # for Python version < 3.11
6667
68+ from orbax .checkpoint import type_handlers
69+ from orbax .checkpoint ._src .handlers import pytree_checkpoint_handler
70+ from orbax .checkpoint ._src .metadata import array_metadata_store as array_metadata_store_lib
71+ from orbax .checkpoint ._src .serialization .type_handlers import PLACEHOLDER
72+ from orbax .checkpoint ._src .serialization .type_handlers import PlaceholderHandler
73+
6774
6875
6976PyTree = Any
96103AsyncOptions = options_lib .AsyncOptions
97104MultiprocessingOptions = options_lib .MultiprocessingOptions
98105FileOptions = options_lib .FileOptions
106+ PyTreeOptions = options_lib .PyTreeOptions
99107
100108DEFAULT_ITEM_NAME = 'default'
101109METRIC_ITEM_NAME = 'metrics'
@@ -367,6 +375,10 @@ class CheckpointManagerOptions:
367375 supposed to be created per process. This is used to support async
368376 directory creation. If True, `multiprocessing_options.primary_host` must be
369377 None.
378+ pytree_options: PyTreeOptions instance to configure PyTree checkpointing
379+ behavior, including array handler options like use_replica_parallel and
380+ enable_pinned_host_transfer. If not provided, default values will be used.
381+ See PyTreeOptions for more details.
370382 """
371383
372384 save_interval_steps : int = 1
@@ -391,6 +403,7 @@ class CheckpointManagerOptions:
391403 read_only : bool = False
392404 enable_async_checkpointing : bool = True
393405 async_options : Optional [AsyncOptions ] = None
406+ pytree_options : Optional [PyTreeOptions ] = None
394407 multiprocessing_options : MultiprocessingOptions = dataclasses .field (
395408 default_factory = MultiprocessingOptions
396409 )
@@ -712,6 +725,7 @@ def __init__(
712725
713726 self ._options = options or CheckpointManagerOptions ()
714727 self ._multiprocessing_options = self ._options .multiprocessing_options
728+ self ._pytree_options = self ._options .pytree_options or PyTreeOptions ()
715729
716730 if self ._options .enable_per_process_directory_creation :
717731 future .AwaitableSignalsContract .awaitable_signals_contract_prefix += (
@@ -1065,12 +1079,23 @@ def _configure_checkpointer_from_item_names_and_handlers(
10651079 for item_name , handler in item_handlers .items ():
10661080 all_item_handlers [item_name ] = handler
10671081
1068- for item_name in all_item_handlers :
1082+ pytree_options = self ._pytree_options
1083+ set_handlers = {}
1084+ for item_name , handler in all_item_handlers .items ():
10691085 if item_name in RESERVED_ITEM_NAMES :
10701086 raise ValueError (
10711087 f'Found { item_name } in `checkpointers`; this is a reserved key.'
10721088 )
1073- all_item_handlers [METRIC_ITEM_NAME ] = self ._metrics_handler
1089+ if handler is None :
1090+ set_handlers [item_name ] = self ._create_default_pytree_handler (
1091+ options , pytree_options
1092+ )
1093+ else :
1094+ set_handlers [item_name ] = handler
1095+
1096+ set_handlers [METRIC_ITEM_NAME ] = self ._metrics_handler
1097+ all_item_handlers = set_handlers
1098+
10741099 # CompositeCheckpointHandler defers per-item handler creation until
10751100 # save/restore time.
10761101 async_options = options .async_options or AsyncOptions ()
@@ -1090,6 +1115,44 @@ def _configure_checkpointer_from_item_names_and_handlers(
10901115 options .enable_async_checkpointing ,
10911116 )
10921117
1118+ def _create_default_pytree_handler (
1119+ self ,
1120+ options : CheckpointManagerOptions ,
1121+ pytree_options : PyTreeOptions
1122+ ) -> pytree_checkpoint_handler .PyTreeCheckpointHandler :
1123+ """Creates a default pytree handler."""
1124+ custom_array_handler = type_handlers .ArrayHandler (
1125+ primary_host = options .multiprocessing_options .primary_host ,
1126+ use_replica_parallel = pytree_options .use_replica_parallel ,
1127+ min_slice_bytes_for_replica_parallel = pytree_options .min_slice_bytes_for_replica_parallel ,
1128+ max_replicas_for_replica_parallel = pytree_options .max_replicas_for_replica_parallel ,
1129+ enable_replica_parallel_separate_folder = pytree_options .enable_replica_parallel_separate_folder ,
1130+ array_metadata_store = array_metadata_store_lib .Store (),
1131+ )
1132+
1133+ custom_registry = type_handlers .create_type_handler_registry (
1134+ (int , type_handlers .ScalarHandler ()),
1135+ (float , type_handlers .ScalarHandler ()),
1136+ (bytes , type_handlers .ScalarHandler ()),
1137+ (np .number , type_handlers .ScalarHandler ()),
1138+ (np .ndarray , type_handlers .NumpyHandler ()),
1139+ (jax .Array , custom_array_handler ),
1140+ (str , type_handlers .StringHandler ()),
1141+ (type (PLACEHOLDER ), PlaceholderHandler ()),
1142+ )
1143+
1144+ return pytree_checkpoint_handler .PyTreeCheckpointHandler (
1145+ save_concurrent_gb = pytree_options .save_concurrent_gb ,
1146+ restore_concurrent_gb = pytree_options .restore_concurrent_gb ,
1147+ save_device_host_concurrent_gb = pytree_options .save_device_host_concurrent_gb ,
1148+ use_ocdbt = pytree_options .use_ocdbt ,
1149+ use_zarr3 = pytree_options .use_zarr3 ,
1150+ use_compression = pytree_options .use_compression ,
1151+ multiprocessing_options = options .multiprocessing_options ,
1152+ type_handler_registry = custom_registry ,
1153+ enable_pinned_host_transfer = pytree_options .enable_pinned_host_transfer ,
1154+ )
1155+
10931156 def _configure_checkpointer_from_handler_registry (
10941157 self ,
10951158 handler_registry : CheckpointHandlerRegistry ,
0 commit comments