Skip to content

Commit 99cc301

Browse files
committed
fix: logger for update_p2p without rank 0
1 parent 3e58010 commit 99cc301

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

checkpoint_engine/ps.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -678,6 +678,7 @@ def __init__(
678678
self._all_hosts = []
679679
self._global_device_uuids: list[str] = []
680680
self._mem_fraction = mem_fraction or 0.9
681+
self._logger_rank = 0
681682

682683
assert self._rank is not None and self._rank >= 0, self._rank
683684
assert self._world_size and self._world_size > 0, self._world_size
@@ -706,8 +707,8 @@ def __init__(
706707
torch.cuda.set_device(device_index)
707708
self._device_uuid = _get_physical_gpu_id(device_index)
708709

709-
def _logger_rank0(self, msg: str):
710-
if self._local_rank == 0:
710+
def _logger_once(self, msg: str):
711+
if self._local_rank == self._logger_rank:
711712
logger.info(msg)
712713

713714
def get_metas(self) -> dict[int, MemoryBufferMetaList]:
@@ -871,10 +872,12 @@ def update(
871872
try:
872873
# if both ranks is None or [], it will use fully broadcast to update to all ranks
873874
if not ranks:
875+
self._logger_rank = 0
874876
if self._auto_pg and not dist.is_initialized():
875877
self.init_process_group()
876878
self._update_per_bucket(checkpoint_name, req_func)
877879
else:
880+
self._logger_rank = ranks[0]
878881
if not self._auto_pg and self._rank not in ranks:
879882
return
880883
if self._auto_pg:
@@ -936,15 +939,15 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
936939
max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape))
937940
free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
938941
if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
939-
self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
942+
self._logger_once(f"[rank{self._rank}] use h2d buffer")
940943
# using h2d_buffer can make all ranks' h2d parallel execution
941944
# the cost is that we need to allocate extra h2d_buffer's GPU memory
942945
free_bytes = free_bytes_divided_3
943946
else:
944947
# if the memory is not enough, it will fallback to disable_h2d_buffer mode,
945948
# at this time, the bandwidth will be limited by the h2d of a single machine,
946949
# but we can save GPU memory
947-
self._logger_rank0(
950+
self._logger_once(
948951
f"[rank{self._rank}] disable h2d buffer when max_tensor_bytes {max_tensor_bytes} is larger than free_bytes {free_bytes} // 3"
949952
)
950953
free_bytes = free_bytes // (2 * _ALIGN_SIZE) * _ALIGN_SIZE
@@ -1074,7 +1077,7 @@ def _update_per_bucket_p2p(
10741077
req_thread.start()
10751078
socket.send_pyobj(handle)
10761079
for gidx, (owner_rank, bucket) in enumerate(buckets):
1077-
self._logger_rank0(
1080+
self._logger_once(
10781081
f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
10791082
)
10801083
_buffer = buffer[gidx % 2 * bucket_size : gidx % 2 * bucket_size + bucket.size]
@@ -1178,7 +1181,7 @@ def _update_per_bucket(
11781181
torch.cuda.memory_allocated() / 1024 / 1024,
11791182
torch.cuda.memory_reserved() / 1024 / 1024,
11801183
)
1181-
self._logger_rank0(
1184+
self._logger_once(
11821185
f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
11831186
f"Current CUDA allocated {alloc:.2f} MB, "
11841187
f"reserved {reserved:.2f} MB."

0 commit comments

Comments
 (0)