diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 58b5a61..ae568c8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -21,7 +21,7 @@ repos:
     rev: v0.12.2
     hooks:
       - id: ruff
-        args: [--fix, --exit-non-zero-on-fix]
+        args: [--fix, --exit-non-zero-on-fix, --ignore, S603,]
       - id: ruff-format
   - repo: https://github.com/codespell-project/codespell
     rev: v2.4.1
diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py
index 1493a69..723bed0 100644
--- a/checkpoint_engine/ps.py
+++ b/checkpoint_engine/ps.py
@@ -678,6 +678,7 @@ def __init__(
         self._all_hosts = []
         self._global_device_uuids: list[str] = []
         self._mem_fraction = mem_fraction or 0.9
+        self._logger_rank = 0
 
         assert self._rank is not None and self._rank >= 0, self._rank
         assert self._world_size and self._world_size > 0, self._world_size
@@ -706,8 +707,8 @@ def __init__(
         torch.cuda.set_device(device_index)
         self._device_uuid = _get_physical_gpu_id(device_index)
 
-    def _logger_rank0(self, msg: str):
-        if self._local_rank == 0:
+    def _logger_once(self, msg: str):
+        if self._local_rank == self._logger_rank:
             logger.info(msg)
 
     def get_metas(self) -> dict[int, MemoryBufferMetaList]:
@@ -871,10 +872,12 @@ def update(
         try:
             # if both ranks is None or [], it will use fully broadcast to update to all ranks
             if not ranks:
+                self._logger_rank = 0
                 if self._auto_pg and not dist.is_initialized():
                     self.init_process_group()
                 self._update_per_bucket(checkpoint_name, req_func)
             else:
+                self._logger_rank = ranks[0]
                 if not self._auto_pg and self._rank not in ranks:
                     return
                 if self._auto_pg:
@@ -936,7 +939,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
             max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape))
         free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
         if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
-            self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
+            self._logger_once(f"[rank{self._rank}] use h2d buffer")
             # using h2d_buffer can make all ranks' h2d parallel execution
             # the cost is that we need to allocate extra h2d_buffer's GPU memory
             free_bytes = free_bytes_divided_3
@@ -944,7 +947,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
             # if the memory is not enough, it will fallback to disable_h2d_buffer mode,
             # at this time, the bandwidth will be limited by the h2d of a single machine,
             # but we can save GPU memory
-            self._logger_rank0(
+            self._logger_once(
                 f"[rank{self._rank}] disable h2d buffer when max_tensor_bytes {max_tensor_bytes} is larger than free_bytes {free_bytes} // 3"
             )
             free_bytes = free_bytes // (2 * _ALIGN_SIZE) * _ALIGN_SIZE
@@ -1074,7 +1077,7 @@ def _update_per_bucket_p2p(
         req_thread.start()
         socket.send_pyobj(handle)
         for gidx, (owner_rank, bucket) in enumerate(buckets):
-            self._logger_rank0(
+            self._logger_once(
                 f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
             )
             _buffer = buffer[gidx % 2 * bucket_size : gidx % 2 * bucket_size + bucket.size]
@@ -1178,7 +1181,7 @@ def _update_per_bucket(
                 torch.cuda.memory_allocated() / 1024 / 1024,
                 torch.cuda.memory_reserved() / 1024 / 1024,
             )
-            self._logger_rank0(
+            self._logger_once(
                 f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
                 f"Current CUDA allocated {alloc:.2f} MB, "
                 f"reserved {reserved:.2f} MB."
diff --git a/tests/test_update.py b/tests/test_update.py
index 2c99256..f426232 100644
--- a/tests/test_update.py
+++ b/tests/test_update.py
@@ -1,7 +1,9 @@
 import os
 import random
+import subprocess
 import time
 
+import pytest
 import torch
 import zmq
 from torch.multiprocessing import Queue, get_context
@@ -63,9 +65,8 @@ def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str,
     check_weights(names_to_check, socket_paths)
 
 
-def run():
+def run_with_specified_ranks(ranks: list[int]):
     rank = int(os.getenv("RANK"))
-    world_size = int(os.getenv("WORLD_SIZE"))
     ctx = get_context("spawn")
     queue = ctx.Queue()
     _device_uuid = _get_physical_gpu_id(rank)
@@ -76,15 +77,53 @@ def run():
     proc.start()
     ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
     ps.gather_metas(checkpoint_name)
-    ranks_list = [[], list(range(world_size // 2)), [], list(range(world_size))]
-    for ranks in ranks_list:
-        ps.update(checkpoint_name, queue.put, ranks=ranks)
-        # sleep 3s to wait process group is destroyed
-        time.sleep(3)
+    ps.update(checkpoint_name, queue.put, ranks=ranks)
+    time.sleep(5)
     ps.unregister_checkpoint(checkpoint_name)
     queue.put(None)
     proc.join()
 
 
+def run():
+    world_size = int(os.getenv("WORLD_SIZE"))
+    random.seed(42)
+    ranklist = [
+        list(random.sample(range(world_size), k=num_ranks)) for num_ranks in range(world_size + 1)
+    ]
+    for ranks in ranklist:
+        run_with_specified_ranks(ranks)
+
+
+@pytest.mark.gpu
+def test_update():
+    world_size = torch.cuda.device_count()
+    assert world_size >= 2, "This test requires at least 2 GPUs."
+
+    master_addr = "localhost"
+    master_port = random.randint(20000, 30000)
+
+    cmd = [
+        "torchrun",
+        "--nproc_per_node",
+        str(world_size),
+        "--master_addr",
+        master_addr,
+        "--master_port",
+        str(master_port),
+        "tests/test_update.py",
+    ]
+
+    result = subprocess.run(
+        cmd,
+        capture_output=False,
+        text=True,
+        cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+        shell=False,
+        check=False,
+    )
+
+    assert result.returncode == 0
+
+
 if __name__ == "__main__":
     run()