Skip to content

Commit 66bd1af

Browse files
mxberlotOrbax Authors
authored andcommitted
Internal change
PiperOrigin-RevId: 877831005
1 parent 60b50ba commit 66bd1af

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

checkpoint/orbax/checkpoint/_src/path/atomicity.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454

5555
import abc
5656
import asyncio
57+
import concurrent.futures
5758
import pickle
5859
import threading
5960
import time
@@ -228,6 +229,42 @@ def get_awaitable_path(self) -> path_types.PathAwaitingCreation:
228229
...
229230

230231

232+
class DeferredPath(path_types.PathAwaitingCreation):
233+
"""A path that is created asynchronously and can be awaited.
234+
235+
Uses concurrent.futures.Future instead of asyncio.Task to avoid
236+
event loop binding issues when create() runs in a different thread.
237+
The Future is thread-safe and can be awaited from any event loop.
238+
"""
239+
240+
def __init__(self):
241+
self._future_path: concurrent.futures.Future[epath.Path] = (
242+
concurrent.futures.Future()
243+
)
244+
245+
def set_path(self, path: epath.Path) -> None:
246+
"""Sets the path result. Called by create() when allocation completes."""
247+
self._future_path.set_result(path)
248+
249+
def __truediv__(
250+
self, other: path_types.PathLike
251+
) -> path_types.PathAwaitingCreation:
252+
child = DeferredPath()
253+
self._future_path.add_done_callback(
254+
lambda f: child.set_path(f.result() / other)
255+
)
256+
return child
257+
258+
@property
259+
def path(self) -> epath.Path:
260+
if not self._future_path.done():
261+
raise ValueError('Path has not been created yet. Call await_creation().')
262+
return self._future_path.result()
263+
264+
async def await_creation(self) -> epath.Path:
265+
return await asyncio.wrap_future(self._future_path)
266+
267+
231268
class ReadOnlyTemporaryPath(atomicity_types.TemporaryPath):
232269
"""A read-only, serializable object providing path properties access.
233270

checkpoint/orbax/checkpoint/_src/path/atomicity_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import asyncio
16+
import concurrent.futures
1517
import stat
1618
import unittest
1719
from absl.testing import absltest
@@ -205,6 +207,33 @@ async def test_finalize_raises(self):
205207
)
206208

207209

210+
class DeferredPathTest(absltest.TestCase):
211+
212+
def test_set_and_get_path(self):
213+
dp = atomicity.DeferredPath()
214+
test_path = epath.Path('/test/path')
215+
dp.set_path(test_path)
216+
self.assertEqual(dp.path, test_path)
217+
218+
def test_path_before_set_raises(self):
219+
dp = atomicity.DeferredPath()
220+
with self.assertRaises(ValueError):
221+
_ = dp.path
222+
223+
def test_await_creation(self):
224+
dp = atomicity.DeferredPath()
225+
test_path = epath.Path('/test/path')
226+
dp.set_path(test_path)
227+
result = asyncio.run(dp.await_creation())
228+
self.assertEqual(result, test_path)
229+
230+
def test_set_path_twice_raises(self):
231+
dp = atomicity.DeferredPath()
232+
dp.set_path(epath.Path('/first'))
233+
with self.assertRaises(concurrent.futures.InvalidStateError):
234+
dp.set_path(epath.Path('/second'))
235+
236+
208237

209238
if __name__ == '__main__':
210239
absltest.main()

0 commit comments

Comments
 (0)