Skip to content

Commit c339d50

Browse files
JustinPan-googOrbax Authors
authored andcommitted
Internal Change
PiperOrigin-RevId: 870544311
1 parent cbdd63f commit c339d50

File tree

11 files changed

+38
-30
lines changed

11 files changed

+38
-30
lines changed

checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,7 @@ def finalize(self, directory: epath.Path):
10441044
# Not an error, as some items may not have been saved.
10451045
continue
10461046
handler.finalize(tmp_dir.get())
1047-
asyncio.run(
1047+
asyncio_utils.run_sync(
10481048
tmp_dir.finalize(
10491049
)
10501050
)

checkpoint/orbax/checkpoint/checkpoint_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
"""High-level checkpoint utils provided for user convenience."""
1616

17-
import asyncio
1817
import contextlib
1918
import time
2019
from typing import Any, Callable, Iterator, Optional
@@ -25,6 +24,7 @@
2524
from jax.experimental import layout
2625
import numpy as np
2726
from orbax.checkpoint import utils
27+
from orbax.checkpoint._src import asyncio_utils
2828
from orbax.checkpoint._src.arrays import sharding as arrays_sharding_lib
2929
from orbax.checkpoint._src.metadata import tree as tree_metadata
3030
from orbax.checkpoint._src.metadata import value as value_metadata
@@ -121,7 +121,7 @@ def _snapshot_checkpoint(
121121
f'Ignoring error when snapshotting checkpoint for step: {step}'
122122
),
123123
):
124-
asyncio.run(snapshot_impl.create_snapshot())
124+
asyncio_utils.run_sync(snapshot_impl.create_snapshot())
125125
return True
126126

127127

@@ -146,7 +146,7 @@ def _release_snapshot(
146146
f'Ignoring error when releasing snapshot for step: {step}'
147147
),
148148
):
149-
asyncio.run(snapshot_impl.release_snapshot())
149+
asyncio_utils.run_sync(snapshot_impl.release_snapshot())
150150

151151

152152
def _reached_desired_step(step: int, until_step: Optional[int]) -> bool:
@@ -409,7 +409,7 @@ def checkpoints_iterator(
409409
f' {step_dir.name}'
410410
),
411411
):
412-
asyncio.run(snapshot_impl.release_snapshot())
412+
asyncio_utils.run_sync(snapshot_impl.release_snapshot())
413413
checkpoint_step = None
414414
while True:
415415
until_step = checkpoint_step + 1 if checkpoint_step is not None else None

checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
always be called across all processes within the primary slice.
2626
"""
2727

28-
import asyncio
2928
import dataclasses
3029
import functools
3130
import time
@@ -40,6 +39,7 @@
4039
from orbax.checkpoint import args as args_lib
4140
from orbax.checkpoint import checkpoint_manager
4241
from orbax.checkpoint import checkpoint_utils
42+
from orbax.checkpoint._src import asyncio_utils
4343
from orbax.checkpoint._src.checkpoint_managers import preservation_policy as preservation_policy_lib
4444
from orbax.checkpoint._src.checkpoint_managers import save_decision_policy as save_decision_policy_lib
4545
from orbax.checkpoint._src.handlers import handler_registration
@@ -991,7 +991,7 @@ def _restore_from_local(
991991
1,
992992
'Debugging single-slice restore_args used for restoration.',
993993
)
994-
asyncio.run(
994+
asyncio_utils.run_sync(
995995
local_checkpoint_data_debugging.print_chunk_debug_info(
996996
restore_directory / _STATE_ITEM_NAME,
997997
single_slice_restore_args,

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Any, Awaitable
2020

2121
from absl import logging
22+
from orbax.checkpoint._src import asyncio_utils
2223
from orbax.checkpoint._src.path import async_path
2324
from orbax.checkpoint._src.path import temporary_paths
2425
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
@@ -75,7 +76,7 @@ def is_orbax_v1_checkpoint(path: path_types.PathLike) -> bool:
7576
ctx = context_lib.get_context()
7677
path = ctx.file_options.path_class(path)
7778
try:
78-
asyncio.run(OrbaxLayout().validate(path))
79+
asyncio_utils.run_sync(OrbaxLayout().validate(path))
7980
return True
8081
except InvalidLayoutError:
8182
return False

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_v0_layout.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import logging
1919
from typing import Any, Awaitable
2020

21+
from orbax.checkpoint._src import asyncio_utils
2122
from orbax.checkpoint._src.metadata import step_metadata_serialization
2223
from orbax.checkpoint._src.path import async_path
2324
from orbax.checkpoint._src.path import temporary_paths
@@ -49,7 +50,7 @@ def is_orbax_v0_checkpoint(path: path_types.PathLike) -> bool:
4950
ctx = context_lib.get_context()
5051
path = ctx.file_options.path_class(path)
5152
try:
52-
asyncio.run(OrbaxV0Layout().validate(path))
53+
asyncio_utils.run_sync(OrbaxV0Layout().validate(path))
5354
return True
5455
except InvalidLayoutError:
5556
return False

checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"""Registry for checkpoint layouts."""
1616

1717
import asyncio
18+
19+
from orbax.checkpoint._src import asyncio_utils
1820
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
1921
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
2022
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
@@ -25,6 +27,7 @@
2527
from orbax.checkpoint.experimental.v1._src.layout import safetensors_layout
2628
from orbax.checkpoint.experimental.v1._src.path import types as path_types
2729

30+
2831
InvalidLayoutError = checkpoint_layout.InvalidLayoutError
2932
CheckpointLayout = checkpoint_layout.CheckpointLayout
3033
CheckpointLayoutEnum = options_lib.CheckpointLayout
@@ -49,7 +52,7 @@ async def _is_orbax_checkpoint_async(path: path_types.PathLike) -> bool:
4952

5053
def is_orbax_checkpoint(path: path_types.PathLike) -> bool:
5154
"""Returns True if the path is an Orbax checkpoint."""
52-
return asyncio.run(_is_orbax_checkpoint_async(path))
55+
return asyncio_utils.run_sync(_is_orbax_checkpoint_async(path))
5356

5457

5558
async def get_layout_class(

checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414

1515
"""Defines free-function interface for loading."""
1616

17-
import asyncio
1817
import functools
1918
import time
2019
from typing import Any, Awaitable, Protocol
2120

2221
from absl import logging
22+
from orbax.checkpoint._src import asyncio_utils
2323
from orbax.checkpoint._src.logging import event_tracking
2424
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
2525
import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import
@@ -115,7 +115,7 @@ def load_pytree(
115115
logging.info('Loading checkpoint from %s.', path)
116116
ctx = context_lib.get_context()
117117
path = ctx.file_options.path_class(path)
118-
layout = asyncio.run(
118+
layout = asyncio_utils.run_sync(
119119
layout_registry.get_checkpoint_layout_pytree(
120120
path, ctx.checkpoint_layout, checkpointable_name
121121
)
@@ -204,7 +204,7 @@ def load_checkpointables(
204204
logging.info('Loading checkpoint from %s.', path)
205205
ctx = context_lib.get_context()
206206
path = ctx.file_options.path_class(path)
207-
layout = asyncio.run(
207+
layout = asyncio_utils.run_sync(
208208
layout_registry.get_checkpoint_layout(path, ctx.checkpoint_layout)
209209
)
210210

@@ -264,7 +264,7 @@ async def _load() -> Any:
264264
)
265265
return result
266266

267-
result = asyncio.run(_load())
267+
result = asyncio_utils.run_sync(_load())
268268

269269
event_tracking.record_read_event(path)
270270

checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
"""Functions for loading metadata from a checkpoint."""
1616

17-
import asyncio
1817
from typing import Any
1918

19+
from orbax.checkpoint._src import asyncio_utils
2020
from orbax.checkpoint.experimental.v1 import errors
2121
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
2222
import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import
@@ -88,7 +88,7 @@ def _get_abstract_array(arr):
8888
ctx = context_lib.get_context()
8989
path = ctx.file_options.path_class(path)
9090

91-
layout = asyncio.run(
91+
layout = asyncio_utils.run_sync(
9292
layout_registry.get_checkpoint_layout_pytree(
9393
path, ctx.checkpoint_layout, checkpointable_name
9494
)
@@ -140,7 +140,7 @@ def checkpointables_metadata(
140140
"""
141141
ctx = context_lib.get_context()
142142
path = ctx.file_options.path_class(path)
143-
layout = asyncio.run(
143+
layout = asyncio_utils.run_sync(
144144
layout_registry.get_checkpoint_layout(path, ctx.checkpoint_layout)
145145
)
146146
return _checkpointables_metadata_impl(layout, path)
@@ -157,4 +157,4 @@ async def _load_metadata() -> (
157157
):
158158
return await layout.metadata(path)
159159

160-
return asyncio.run(_load_metadata())
160+
return asyncio_utils.run_sync(_load_metadata())

checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from __future__ import annotations
1818

19-
import asyncio
2019
import hashlib
2120
import time
2221
from typing import Any, Awaitable, Iterable
@@ -25,6 +24,7 @@
2524
from absl import logging
2625
import jax
2726
import numpy as np
27+
from orbax.checkpoint._src import asyncio_utils
2828
from orbax.checkpoint._src.futures import future
2929
from orbax.checkpoint._src.logging import event_tracking
3030
from orbax.checkpoint._src.metadata import step_metadata_serialization
@@ -380,7 +380,7 @@ def save_checkpointables_impl(
380380
event_tracking.record_save_start(path, async_origin=async_origin)
381381
# Ensure the operation ID is incremented as soon as possible. This must be
382382
# done uniquely for each save operation.
383-
asyncio.run(context_lib.synchronize_next_operation_id())
383+
asyncio_utils.run_sync(context_lib.synchronize_next_operation_id())
384384
context = context_lib.get_context()
385385

386386
path = context.file_options.path_class(path)
@@ -397,7 +397,7 @@ def save_checkpointables_impl(
397397
subdirectories=subdirectories,
398398
use_snapshot=path_exists,
399399
)
400-
background_awaitable = asyncio.run(
400+
background_awaitable = asyncio_utils.run_sync(
401401
_run_blocking_save(
402402
temporary_path,
403403
checkpointables,

checkpoint/orbax/checkpoint/experimental/v1/_src/testing/handler_utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from __future__ import annotations
1818

19-
import asyncio
2019
import dataclasses
2120
import json
2221
from typing import Any, Awaitable, Generic, Type, TypeVar
@@ -25,6 +24,7 @@
2524
from etils import epath
2625
from orbax.checkpoint import checkpoint_args as v0_args
2726
from orbax.checkpoint import handlers as v0_handlers
27+
from orbax.checkpoint._src import asyncio_utils
2828
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
2929
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
3030
from orbax.checkpoint.experimental.v1._src.path import types as path_types
@@ -59,24 +59,26 @@ def __init__(
5959

6060
def save(self, directory: Path, checkpointable: T):
6161
path = path_test_utils.PathAwaitingCreationWrapper(directory)
62-
awaitable = asyncio.run(self._handler.save(path, checkpointable))
63-
return asyncio.run(_run_awaitable(awaitable))
62+
awaitable = asyncio_utils.run_sync(self._handler.save(path, checkpointable))
63+
return asyncio_utils.run_sync(_run_awaitable(awaitable))
6464

6565
def save_async(self, directory: Path, checkpointable: T):
6666
path = path_test_utils.PathAwaitingCreationWrapper(directory)
67-
return asyncio.run(self._handler.save(path, checkpointable))
67+
return asyncio_utils.run_sync(self._handler.save(path, checkpointable))
6868

6969
def load(self, path: Path, abstract_checkpointable: AbstractT | None = None):
7070
awaitable = self.load_async(path, abstract_checkpointable)
71-
return asyncio.run(_run_awaitable(awaitable))
71+
return asyncio_utils.run_sync(_run_awaitable(awaitable))
7272

7373
def load_async(
7474
self, path: Path, abstract_checkpointable: AbstractT | None = None
7575
):
76-
return asyncio.run(self._handler.load(path, abstract_checkpointable))
76+
return asyncio_utils.run_sync(
77+
self._handler.load(path, abstract_checkpointable)
78+
)
7779

7880
def metadata(self, path: Path) -> AbstractT:
79-
return asyncio.run(self._handler.metadata(path))
81+
return asyncio_utils.run_sync(self._handler.metadata(path))
8082

8183
def is_handleable(self, checkpointable: Any) -> bool:
8284
return self._handler.is_handleable(checkpointable)

0 commit comments

Comments
 (0)