@@ -678,6 +678,7 @@ def __init__(
         self._all_hosts = []
         self._global_device_uuids: list[str] = []
         self._mem_fraction = mem_fraction or 0.9
+        self._logger_rank = 0

         assert self._rank is not None and self._rank >= 0, self._rank
         assert self._world_size and self._world_size > 0, self._world_size
@@ -706,8 +707,8 @@ def __init__(
         torch.cuda.set_device(device_index)
         self._device_uuid = _get_physical_gpu_id(device_index)

-    def _logger_rank0(self, msg: str):
-        if self._local_rank == 0:
+    def _logger_once(self, msg: str):
+        if self._local_rank == self._logger_rank:
             logger.info(msg)

     def get_metas(self) -> dict[int, MemoryBufferMetaList]:
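The rename above matters because in a `torch.distributed` job every process runs the same code, so an unguarded `logger.info` prints once per rank. A minimal standalone sketch of the gating pattern (the class and the environment handling are hypothetical, not from this diff):

```python
import logging
import os

logger = logging.getLogger(__name__)

class LoggerOnceSketch:
    """Hypothetical stand-in for the real updater class."""

    def __init__(self):
        # LOCAL_RANK is the per-node rank set by launchers such as torchrun
        self._local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        # which rank is allowed to log; update() reassigns this per call
        self._logger_rank = 0

    def _logger_once(self, msg: str):
        # only the designated rank emits the line, so each progress message
        # appears once instead of once per participating process
        if self._local_rank == self._logger_rank:
            logger.info(msg)
```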
@@ -871,10 +872,12 @@ def update(
         try:
             # if ranks is None or [], fall back to a full broadcast update to all ranks
             if not ranks:
+                self._logger_rank = 0
                 if self._auto_pg and not dist.is_initialized():
                     self.init_process_group()
                 self._update_per_bucket(checkpoint_name, req_func)
             else:
+                self._logger_rank = ranks[0]
                 if not self._auto_pg and self._rank not in ranks:
                     return
                 if self._auto_pg:
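Hypothetical call sites for the two branches above; `updater` and any part of the `update()` signature beyond what this diff shows are assumptions:

```python
# Broadcast path: ranks is None/empty, every rank participates,
# and rank 0 becomes the logging rank.
updater.update("ckpt-step-100", req_func)

# Targeted path: only the listed ranks participate; ranks[0] (here 4)
# becomes the logging rank, so progress lines still appear even when
# rank 0 is not involved in the update.
updater.update("ckpt-step-100", req_func, ranks=[4, 5, 6, 7])
```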
@@ -936,15 +939,15 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
             max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape))
         free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
         if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
-            self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
+            self._logger_once(f"[rank{self._rank}] use h2d buffer")
             # using the h2d buffer lets all ranks run their h2d copies in parallel;
             # the cost is the extra GPU memory allocated for the h2d buffer
             free_bytes = free_bytes_divided_3
         else:
             # if memory is insufficient, fall back to disable_h2d_buffer mode;
             # bandwidth is then limited by the h2d throughput of a single machine,
             # but GPU memory is saved
-            self._logger_rank0(
+            self._logger_once(
                 f"[rank{self._rank}] disable h2d buffer when max_tensor_bytes {max_tensor_bytes} is larger than free_bytes {free_bytes} // 3"
             )
             free_bytes = free_bytes // (2 * _ALIGN_SIZE) * _ALIGN_SIZE
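The `free_bytes // (n * _ALIGN_SIZE) * _ALIGN_SIZE` idiom above yields the largest multiple of `_ALIGN_SIZE` that fits in a 1/n share of free memory; presumably n is 3 when a dedicated h2d staging buffer joins the two double-buffer slots, and 2 when only the two slots remain. A worked sketch (the real `_ALIGN_SIZE` value is not shown in this diff; 256 is an assumption):

```python
_ALIGN_SIZE = 256  # assumed value; the real constant is defined elsewhere

def aligned_share(free_bytes: int, parts: int) -> int:
    # largest multiple of _ALIGN_SIZE not exceeding free_bytes / parts
    return free_bytes // (parts * _ALIGN_SIZE) * _ALIGN_SIZE

free_bytes = 1_000_000
print(aligned_share(free_bytes, 3))  # 333312 == 256 * 1302, <= 333333.33
print(aligned_share(free_bytes, 2))  # 499968 == 256 * 1953, <= 500000
```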
@@ -1074,7 +1077,7 @@ def _update_per_bucket_p2p(
         req_thread.start()
         socket.send_pyobj(handle)
         for gidx, (owner_rank, bucket) in enumerate(buckets):
-            self._logger_rank0(
+            self._logger_once(
                 f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f} MiB, length: {len(bucket.items)}. "
             )
             _buffer = buffer[gidx % 2 * bucket_size : gidx % 2 * bucket_size + bucket.size]
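The `gidx % 2 * bucket_size` slicing above is ping-pong double buffering: even-numbered buckets use the first half of the staging buffer and odd-numbered buckets the second half, so one bucket can be transferred while the previous one is still being consumed. A small sketch of the index arithmetic, with a Python list standing in for the GPU tensor:

```python
bucket_size = 4
buffer = list(range(2 * bucket_size))  # two slots of bucket_size each

for gidx, size in enumerate([3, 4, 2]):  # per-bucket payload sizes
    start = gidx % 2 * bucket_size       # alternates slot 0, slot 1, slot 0, ...
    _buffer = buffer[start : start + size]
    print(gidx, start, _buffer)
# 0 0 [0, 1, 2]
# 1 4 [4, 5, 6, 7]
# 2 0 [0, 1]
```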
@@ -1178,7 +1181,7 @@ def _update_per_bucket(
                 torch.cuda.memory_allocated() / 1024 / 1024,
                 torch.cuda.memory_reserved() / 1024 / 1024,
             )
-            self._logger_rank0(
+            self._logger_once(
                 f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f} MiB, length: {len(bucket.items)}. "
                 f"Current CUDA allocated {alloc:.2f} MiB, "
                 f"reserved {reserved:.2f} MiB."
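For reference, a minimal helper matching the stats in that log line. `torch.cuda.memory_allocated()` and `torch.cuda.memory_reserved()` report bytes, and `reserved` is at least `allocated` because the caching allocator holds on to freed blocks:

```python
import torch

def cuda_mem_mib() -> tuple[float, float]:
    # bytes -> MiB, mirroring the /1024/1024 conversions in the diff
    alloc = torch.cuda.memory_allocated() / 1024 / 1024
    reserved = torch.cuda.memory_reserved() / 1024 / 1024
    return alloc, reserved
```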