Skip to content

Commit 45b5677

Browse files
author
Orbax Authors
committed
Internal
PiperOrigin-RevId: 733831700
1 parent c1a21a6 commit 45b5677

File tree

3 files changed

+24
-0
lines changed

3 files changed

+24
-0
lines changed

checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py

+5
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from orbax.checkpoint._src.path import async_utils
3636
from orbax.checkpoint._src.path import atomicity
3737
from orbax.checkpoint._src.path import atomicity_types
38+
from orbax.checkpoint._src.path import utils as path_utils
3839

3940

4041

@@ -423,6 +424,10 @@ async def _save(
423424
directory = tmpdir.get_final()
424425
self.synchronize_next_awaitable_signal_operation_id()
425426

427+
jax.monitoring.record_event(
428+
'/jax/orbax/write/async/storage_type',
429+
storage_type=path_utils.get_storage_type(directory),
430+
)
426431
jax.monitoring.record_event('/jax/orbax/write/async/start')
427432
logging.info(
428433
'[process=%s] Started async saving checkpoint to %s.',

checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py

+5
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from orbax.checkpoint._src.path import atomicity
3636
from orbax.checkpoint._src.path import atomicity_defaults
3737
from orbax.checkpoint._src.path import atomicity_types
38+
from orbax.checkpoint._src.path import utils as path_utils
3839
from typing_extensions import Self # for Python version < 3.11
3940

4041

@@ -212,6 +213,10 @@ def save(
212213
checkpoint_start_time = time.time()
213214
directory = epath.Path(directory)
214215

216+
jax.monitoring.record_event(
217+
'/jax/orbax/write/storage_type',
218+
storage_type=path_utils.get_storage_type(directory),
219+
)
215220
jax.monitoring.record_event('/jax/orbax/write/start')
216221
logging.info(
217222
'[process=%s] Started saving checkpoint to %s.',

checkpoint/orbax/checkpoint/_src/path/utils.py

+14
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,20 @@
2222

2323

2424

25+
_GCS_PATH_PREFIX = ('gs://',)
26+
27+
28+
def is_gcs_path(path: epath.Path) -> bool:
29+
return path.as_posix().startswith(_GCS_PATH_PREFIX)
30+
31+
32+
def get_storage_type(path: epath.Path) -> str:
33+
if is_gcs_path(path):
34+
return 'gcs'
35+
else:
36+
return 'local'
37+
38+
2539
class Timer(object):
2640
"""A simple timer to measure the time it takes to run a function."""
2741

0 commit comments

Comments
 (0)