File tree 3 files changed +24
-0
lines changed
checkpoint/orbax/checkpoint/_src
3 files changed +24
-0
lines changed Original file line number Diff line number Diff line change 35
35
from orbax .checkpoint ._src .path import async_utils
36
36
from orbax .checkpoint ._src .path import atomicity
37
37
from orbax .checkpoint ._src .path import atomicity_types
38
+ from orbax .checkpoint ._src .path import utils as path_utils
38
39
39
40
40
41
@@ -423,6 +424,10 @@ async def _save(
423
424
directory = tmpdir .get_final ()
424
425
self .synchronize_next_awaitable_signal_operation_id ()
425
426
427
+ jax .monitoring .record_event (
428
+ '/jax/orbax/write/async/storage_type' ,
429
+ storage_type = path_utils .get_storage_type (directory ),
430
+ )
426
431
jax .monitoring .record_event ('/jax/orbax/write/async/start' )
427
432
logging .info (
428
433
'[process=%s] Started async saving checkpoint to %s.' ,
Original file line number Diff line number Diff line change 35
35
from orbax .checkpoint ._src .path import atomicity
36
36
from orbax .checkpoint ._src .path import atomicity_defaults
37
37
from orbax .checkpoint ._src .path import atomicity_types
38
+ from orbax .checkpoint ._src .path import utils as path_utils
38
39
from typing_extensions import Self # for Python version < 3.11
39
40
40
41
@@ -212,6 +213,10 @@ def save(
212
213
checkpoint_start_time = time .time ()
213
214
directory = epath .Path (directory )
214
215
216
+ jax .monitoring .record_event (
217
+ '/jax/orbax/write/storage_type' ,
218
+ storage_type = path_utils .get_storage_type (directory ),
219
+ )
215
220
jax .monitoring .record_event ('/jax/orbax/write/start' )
216
221
logging .info (
217
222
'[process=%s] Started saving checkpoint to %s.' ,
Original file line number Diff line number Diff line change 22
22
23
23
24
24
25
+ _GCS_PATH_PREFIX = ('gs://' ,)
26
+
27
+
28
+ def is_gcs_path (path : epath .Path ) -> bool :
29
+ return path .as_posix ().startswith (_GCS_PATH_PREFIX )
30
+
31
+
32
+ def get_storage_type (path : epath .Path ) -> str :
33
+ if is_gcs_path (path ):
34
+ return 'gcs'
35
+ else :
36
+ return 'local'
37
+
38
+
25
39
class Timer (object ):
26
40
"""A simple timer to measure the time it takes to run a function."""
27
41
You can’t perform that action at this time.
0 commit comments