Skip to content

Commit 30d40ed

Browse files
mxberlotOrbax Authors
authored andcommitted
Internal change
PiperOrigin-RevId: 876100881
1 parent 99bfb4b commit 30d40ed

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
4545
from orbax.checkpoint._src.metadata import empty_values
4646
from orbax.checkpoint._src.metadata import tree as tree_metadata
47+
from orbax.checkpoint._src.path import types as path_types
4748
from orbax.checkpoint._src.serialization import limits
4849
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
4950
from orbax.checkpoint._src.serialization import type_handler_registry as handler_registry
@@ -470,7 +471,9 @@ def _concurrent_bytes(
470471
return concurrent_gb * 10**9
471472

472473

473-
class PyTreeCheckpointHandler(async_checkpoint_handler.AsyncCheckpointHandler):
474+
class PyTreeCheckpointHandler(
475+
async_checkpoint_handler.DeferredPathAsyncCheckpointHandler
476+
):
474477
"""A CheckpointHandler implementation for any PyTree structure.
475478
476479
See JAX documentation for more information on what consistutes a "PyTree".
@@ -608,7 +611,7 @@ def __init__(
608611

609612
async def async_save(
610613
self,
611-
directory: epath.Path,
614+
directory: epath.Path | path_types.PathAwaitingCreation,
612615
item: Optional[PyTree] = None,
613616
save_args: Optional[PyTreeSaveArgs] = None,
614617
args: Optional[PyTreeSaveArgs] = None,

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from orbax.checkpoint._src.serialization import type_handlers
6262
from orbax.checkpoint._src.testing import multiprocess_test
6363
from orbax.checkpoint._src.tree import utils as tree_utils
64+
from orbax.checkpoint.google.path import tfhub_atomicity
6465

6566

6667
PyTree = Any
@@ -2948,6 +2949,51 @@ def test_partial_restore_with_omission_unexpected_keys(
29482949
)
29492950
test_utils.assert_tree_equal(self, expected, restored)
29502951

2952+
async def test_save_with_deferred_path(self):
2953+
"""Tests that async_save works with deferred paths."""
2954+
deferred_path = tfhub_atomicity.DeferredPath()
2955+
save_dir = self.directory / 'deferred_path_ckpt'
2956+
await_creation_called = False
2957+
original_await = tfhub_atomicity.DeferredPath.await_creation
2958+
2959+
async def mock_await_creation(dp_self):
2960+
"""Sets the path only once await_creation is called.
2961+
2962+
This ensures the path is not resolved before the handler awaits it, fully
2963+
exercising the deferred path resolution contract.
2964+
2965+
Args:
2966+
dp_self: The DeferredPath instance.
2967+
2968+
Returns:
2969+
The result of the original await_creation method.
2970+
"""
2971+
nonlocal await_creation_called
2972+
if not dp_self._future_path.done():
2973+
save_dir.mkdir(parents=True, exist_ok=True)
2974+
dp_self.set_path(save_dir)
2975+
await_creation_called = True
2976+
return await original_await(dp_self)
2977+
2978+
with self.ocdbt_checkpoint_handler(use_ocdbt=False) as handler:
2979+
with mock.patch.object(
2980+
tfhub_atomicity.DeferredPath,
2981+
'await_creation',
2982+
mock_await_creation,
2983+
):
2984+
await handler.async_save(
2985+
deferred_path,
2986+
args=PyTreeSaveArgs(self.pytree),
2987+
)
2988+
2989+
self.assertTrue(await_creation_called)
2990+
self.validate_save(
2991+
save_dir,
2992+
self.pytree,
2993+
handler,
2994+
restore_args=self.restore_args,
2995+
)
2996+
29512997

29522998
if __name__ == '__main__':
29532999
multiprocess_test.main()

0 commit comments

Comments
 (0)