@@ -951,7 +951,7 @@ def _record_saved_duration(checkpoint_start_time: float):
951951 # Note: for the very first checkpoint, this is the interval between program
952952 # init and the current checkpoint start time.
953953 duration_since_last_checkpoint = checkpoint_start_time - _LAST_CHECKPOINT_WRITE_TIME
954- if jax . version . __version_info__ > ( 0 , 3 , 25 ) :
954+ if monitoring is not None :
955955 monitoring .record_event_duration_secs (
956956 '/jax/checkpoint/write/duration_since_last_checkpoint_secs' ,
957957 duration_since_last_checkpoint )
@@ -1151,7 +1151,7 @@ def save_main_ckpt_task():
11511151 else :
11521152 save_main_ckpt_task ()
11531153 end_time = time .time ()
1154- if jax . version . __version_info__ > ( 0 , 3 , 25 ) :
1154+ if monitoring is not None :
11551155 monitoring .record_event_duration_secs (_WRITE_CHECKPOINT_EVENT ,
11561156 end_time - start_time )
11571157 return ckpt_path
@@ -1281,7 +1281,7 @@ def save_main_ckpt_task():
12811281 else :
12821282 save_main_ckpt_task ()
12831283 end_time = time .time ()
1284- if jax . version . __version_info__ > ( 0 , 3 , 25 ) :
1284+ if monitoring is not None :
12851285 monitoring .record_event_duration_secs (_WRITE_CHECKPOINT_EVENT ,
12861286 end_time - start_time )
12871287
@@ -1390,7 +1390,7 @@ def save_main_ckpt_task():
13901390 keep , overwrite , keep_every_n_steps , start_time , async_manager )
13911391
13921392 end_time = time .time ()
1393- if jax . version . __version_info__ > ( 0 , 3 , 25 ) :
1393+ if monitoring is not None :
13941394 monitoring .record_event_duration_secs (_WRITE_CHECKPOINT_EVENT ,
13951395 end_time - start_time )
13961396 return ckpt_path
@@ -1553,7 +1553,7 @@ def read_chunk(i):
15531553 restored_checkpoint = from_state_dict (target , state_dict )
15541554
15551555 end_time = time .time ()
1556- if jax . version . __version_info__ > ( 0 , 3 , 25 ) :
1556+ if monitoring is not None :
15571557 monitoring .record_event_duration_secs (_READ_CHECKPOINT_EVENT , end_time - start_time )
15581558
15591559 return restored_checkpoint
@@ -1616,7 +1616,7 @@ def read_chunk(i):
16161616
16171617 state_dict = msgpack_restore (checkpoint_contents )
16181618 end_time = time .time ()
1619- if jax . version . __version_info__ > ( 0 , 3 , 25 ) :
1619+ if monitoring is not None :
16201620 monitoring .record_event_duration_secs (_READ_CHECKPOINT_EVENT , end_time - start_time )
16211621
16221622 return state_dict
0 commit comments