Skip to content

Commit 41b6356

Browse files
mxberlotOrbax Authors
authored andcommitted
Internal change
PiperOrigin-RevId: 876100881
1 parent 99bfb4b commit 41b6356

File tree

5 files changed

+126
-2
lines changed

5 files changed

+126
-2
lines changed

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
4545
from orbax.checkpoint._src.metadata import empty_values
4646
from orbax.checkpoint._src.metadata import tree as tree_metadata
47+
from orbax.checkpoint._src.path import types as path_types
4748
from orbax.checkpoint._src.serialization import limits
4849
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
4950
from orbax.checkpoint._src.serialization import type_handler_registry as handler_registry
@@ -470,7 +471,9 @@ def _concurrent_bytes(
470471
return concurrent_gb * 10**9
471472

472473

473-
class PyTreeCheckpointHandler(async_checkpoint_handler.AsyncCheckpointHandler):
474+
class PyTreeCheckpointHandler(
475+
async_checkpoint_handler.DeferredPathAsyncCheckpointHandler
476+
):
474477
"""A CheckpointHandler implementation for any PyTree structure.
475478
476479
See JAX documentation for more information on what consistutes a "PyTree".
@@ -608,7 +611,7 @@ def __init__(
608611

609612
async def async_save(
610613
self,
611-
directory: epath.Path,
614+
directory: epath.Path | path_types.PathAwaitingCreation,
612615
item: Optional[PyTree] = None,
613616
save_args: Optional[PyTreeSaveArgs] = None,
614617
args: Optional[PyTreeSaveArgs] = None,

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import functools
2222
import json
2323
import re
24+
import threading
2425
from typing import Any, Iterator, List, NamedTuple, Optional, Sequence
2526
import unittest
2627
from unittest import mock
@@ -54,6 +55,7 @@
5455
from orbax.checkpoint._src.metadata import tree as tree_metadata
5556
from orbax.checkpoint._src.metadata import value as value_metadata
5657
from orbax.checkpoint._src.multihost import multihost
58+
from orbax.checkpoint._src.path import atomicity
5759
from orbax.checkpoint._src.serialization import limits
5860
from orbax.checkpoint._src.serialization import replica_slices
5961
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
@@ -2948,6 +2950,56 @@ def test_partial_restore_with_omission_unexpected_keys(
29482950
)
29492951
test_utils.assert_tree_equal(self, expected, restored)
29502952

2953+
async def test_save_with_deferred_path(self):
2954+
"""Tests that async_save works with deferred paths."""
2955+
deferred_path = atomicity.DeferredPath()
2956+
save_dir = self.directory / 'deferred_path_ckpt'
2957+
await_creation_called = False
2958+
original_await = atomicity.DeferredPath.await_creation
2959+
set_path_lock = threading.Lock()
2960+
2961+
async def mock_await_creation(dp_self):
2962+
"""Sets the path only once await_creation is called.
2963+
2964+
This ensures the path is not resolved before the handler awaits it, fully
2965+
exercising the deferred path resolution contract.
2966+
2967+
Args:
2968+
dp_self: The DeferredPath instance.
2969+
2970+
Returns:
2971+
The result of the original await_creation method.
2972+
"""
2973+
nonlocal await_creation_called
2974+
with set_path_lock:
2975+
if not dp_self._future_path.done():
2976+
save_dir.mkdir(parents=True, exist_ok=True)
2977+
dp_self.set_path(save_dir)
2978+
await_creation_called = True
2979+
return await original_await(dp_self)
2980+
2981+
with self.ocdbt_checkpoint_handler(use_ocdbt=False) as handler:
2982+
with mock.patch.object(
2983+
atomicity.DeferredPath,
2984+
'await_creation',
2985+
mock_await_creation,
2986+
):
2987+
commit_futures = await handler.async_save(
2988+
deferred_path,
2989+
args=PyTreeSaveArgs(self.pytree),
2990+
)
2991+
if commit_futures:
2992+
for f in commit_futures:
2993+
f.result()
2994+
2995+
self.assertTrue(await_creation_called)
2996+
self.validate_save(
2997+
save_dir,
2998+
self.pytree,
2999+
handler,
3000+
restore_args=self.restore_args,
3001+
)
3002+
29513003

29523004
if __name__ == '__main__':
29533005
multiprocess_test.main()

checkpoint/orbax/checkpoint/_src/path/atomicity.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454

5555
import abc
5656
import asyncio
57+
import concurrent.futures
5758
import pickle
5859
import threading
5960
import time
@@ -228,6 +229,42 @@ def get_awaitable_path(self) -> path_types.PathAwaitingCreation:
228229
...
229230

230231

232+
class DeferredPath(path_types.PathAwaitingCreation):
233+
"""A path that is created asynchronously and can be awaited.
234+
235+
Uses concurrent.futures.Future instead of asyncio.Task to avoid
236+
event loop binding issues when create() runs in a different thread.
237+
The Future is thread-safe and can be awaited from any event loop.
238+
"""
239+
240+
def __init__(self):
241+
self._future_path: concurrent.futures.Future[epath.Path] = (
242+
concurrent.futures.Future()
243+
)
244+
245+
def set_path(self, path: epath.Path) -> None:
246+
"""Sets the path result. Called by create() when allocation completes."""
247+
self._future_path.set_result(path)
248+
249+
def __truediv__(
250+
self, other: path_types.PathLike
251+
) -> path_types.PathAwaitingCreation:
252+
child = DeferredPath()
253+
self._future_path.add_done_callback(
254+
lambda f: child.set_path(f.result() / other)
255+
)
256+
return child
257+
258+
@property
259+
def path(self) -> epath.Path:
260+
if not self._future_path.done():
261+
raise ValueError('Path has not been created yet. Call await_creation().')
262+
return self._future_path.result()
263+
264+
async def await_creation(self) -> epath.Path:
265+
return await asyncio.wrap_future(self._future_path)
266+
267+
231268
class ReadOnlyTemporaryPath(atomicity_types.TemporaryPath):
232269
"""A read-only, serializable object providing path properties access.
233270

checkpoint/orbax/checkpoint/_src/path/atomicity_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import asyncio
16+
import concurrent.futures
1517
import stat
1618
import unittest
1719
from absl.testing import absltest
@@ -205,6 +207,33 @@ async def test_finalize_raises(self):
205207
)
206208

207209

210+
class DeferredPathTest(absltest.TestCase):
211+
212+
def test_set_and_get_path(self):
213+
dp = atomicity.DeferredPath()
214+
test_path = epath.Path('/test/path')
215+
dp.set_path(test_path)
216+
self.assertEqual(dp.path, test_path)
217+
218+
def test_path_before_set_raises(self):
219+
dp = atomicity.DeferredPath()
220+
with self.assertRaises(ValueError):
221+
_ = dp.path
222+
223+
def test_await_creation(self):
224+
dp = atomicity.DeferredPath()
225+
test_path = epath.Path('/test/path')
226+
dp.set_path(test_path)
227+
result = asyncio.run(dp.await_creation())
228+
self.assertEqual(result, test_path)
229+
230+
def test_set_path_twice_raises(self):
231+
dp = atomicity.DeferredPath()
232+
dp.set_path(epath.Path('/first'))
233+
with self.assertRaises(concurrent.futures.InvalidStateError):
234+
dp.set_path(epath.Path('/second'))
235+
236+
208237

209238
if __name__ == '__main__':
210239
absltest.main()

checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,10 @@ def _serialize_batch(
516516
' scheduled asynchronously.'
517517
)
518518

519+
all_infos = infos
519520
async def _serialize():
521+
for info in all_infos:
522+
await info.await_path_creation()
520523
if prioritized:
521524
arrays, infos, args = zip(*prioritized)
522525
_serialize_batch(infos, args, arrays)

0 commit comments

Comments
 (0)