
Commit 796be51

refactor(core): abstract torch.distributed APIs in CheckpointLoader
Add injectable callable parameters to DefaultMLFlashpointCheckpointLoader for get_rank, get_node_local_rank, broadcast_object_list, all_gather_object, and get_world_size, mirroring the pattern already used in DefaultMLFlashpointCheckpointSaver. All parameters default to the corresponding torch.distributed functions, preserving backwards compatibility.

This makes the loader easier to test via dependency injection and allows swapping implementations without subclassing or monkey-patching torch.distributed.

Closes #30
Parent: 6f36c9c
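For orientation, a minimal sketch of what the new wiring looks like at a call site, passing the real torch.distributed functions explicitly (the construction site itself is an assumption; the two manager objects are assumed to be built elsewhere):

```python
import torch.distributed as dist

from ml_flashpoint.core.checkpoint_loader import DefaultMLFlashpointCheckpointLoader

# Production wiring: inject the real torch.distributed collectives and
# rank/world-size accessors. Each parameter is a plain callable, so tests
# can substitute stubs without touching torch.distributed itself.
loader = DefaultMLFlashpointCheckpointLoader(
    checkpoint_object_manager,  # assumed constructed elsewhere
    replication_manager,        # assumed constructed elsewhere
    global_rank_getter=dist.get_rank,
    local_rank_getter=dist.get_node_local_rank,
    broadcast_object_list_func=dist.broadcast_object_list,
    all_gather_object_func=dist.all_gather_object,
    world_size_getter=dist.get_world_size,
)
```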

File tree: 4 files changed (+234, −158)


src/ml_flashpoint/adapter/nemo/nemo_checkpoint_loader.py

Lines changed: 23 additions & 2 deletions

```diff
@@ -14,7 +14,7 @@

 import os
 from pathlib import Path
-from typing import List, Set
+from typing import Callable, List, Set

 from typing_extensions import override

@@ -33,6 +33,12 @@ def __init__(
         self,
         checkpoint_object_manager: CheckpointObjectManager,
         replication_manager: ReplicationManager,
+        *,
+        global_rank_getter: Callable[[], int],
+        local_rank_getter: Callable[[], int],
+        broadcast_object_list_func: Callable[..., None],
+        all_gather_object_func: Callable[..., None],
+        world_size_getter: Callable[[], int],
         recover_context: bool = False,
     ):
         """Initializes the NeMoMLFlashpointCheckpointLoader.
@@ -42,9 +48,24 @@ def __init__(
                 reading data.
             replication_manager: The replication manager to use for retrieving
                 missing checkpoint objects from peer nodes.
+            global_rank_getter: A callable that returns the global rank.
+            local_rank_getter: A callable that returns the node-local rank.
+            broadcast_object_list_func: A callable with the same signature as
+                ``torch.distributed.broadcast_object_list``.
+            all_gather_object_func: A callable with the same signature as
+                ``torch.distributed.all_gather_object``.
+            world_size_getter: A callable that returns the world size.
             recover_context: Whether to recover the context directory if missing.
         """
-        super().__init__(checkpoint_object_manager, replication_manager)
+        super().__init__(
+            checkpoint_object_manager,
+            replication_manager,
+            global_rank_getter=global_rank_getter,
+            local_rank_getter=local_rank_getter,
+            broadcast_object_list_func=broadcast_object_list_func,
+            all_gather_object_func=all_gather_object_func,
+            world_size_getter=world_size_getter,
+        )
         self._recover_context = recover_context

     @override
```
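Worth noting: the subclass adds no distributed logic of its own here. It simply threads the five callables through to the base class, and the bare `*` keeps them keyword-only, so an existing positional call site fails loudly instead of silently binding a callable to recover_context.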

src/ml_flashpoint/core/checkpoint_loader.py

Lines changed: 31 additions & 14 deletions

```diff
@@ -22,7 +22,7 @@
 import struct
 from collections import defaultdict
 from pathlib import Path
-from typing import IO, List, Optional, Set, Tuple, TypeVar, cast
+from typing import IO, Callable, List, Optional, Set, Tuple, TypeVar, cast

 import torch
 import torch.distributed as dist
@@ -128,6 +128,12 @@ def __init__(
         self,
         checkpoint_object_manager: CheckpointObjectManager,
         replication_manager: ReplicationManager,
+        *,
+        global_rank_getter: Callable[[], int],
+        local_rank_getter: Callable[[], int],
+        broadcast_object_list_func: Callable[..., None],
+        all_gather_object_func: Callable[..., None],
+        world_size_getter: Callable[[], int],
     ):
         """Initializes the DefaultMLFlashpointCheckpointLoader.

@@ -136,9 +142,21 @@ def __init__(
                 reading data.
             replication_manager: The replication manager to use for retrieving
                 missing checkpoint objects from peer nodes.
+            global_rank_getter: A callable that returns the global rank.
+            local_rank_getter: A callable that returns the node-local rank.
+            broadcast_object_list_func: A callable with the same signature as
+                ``torch.distributed.broadcast_object_list``.
+            all_gather_object_func: A callable with the same signature as
+                ``torch.distributed.all_gather_object``.
+            world_size_getter: A callable that returns the world size.
         """
         self._checkpoint_object_manager = checkpoint_object_manager
         self._replication_manager = replication_manager
+        self._global_rank_getter = global_rank_getter
+        self._local_rank_getter = local_rank_getter
+        self._broadcast_object_list_func = broadcast_object_list_func
+        self._all_gather_object_func = all_gather_object_func
+        self._world_size_getter = world_size_getter
         # Cache for available objects: CheckpointContainerId -> dict[object_path, list[rank]]
         self._available_objects_cache: dict[CheckpointContainerId, dict[str, List[int]]] = {}

@@ -337,8 +355,7 @@ def get_latest_complete_checkpoint(
             else continue to the next candidate checkpoint
         - return the checkpoint container id of the latest complete checkpoint
         """
-        # TODO: use global_rank_getter and local_rank_getter.
-        rank = dist.get_rank()
+        rank = self._global_rank_getter()
         _LOGGER.debug(
             "Rank %s: Getting latest complete checkpoint for '%s'",
             rank,
@@ -382,7 +399,7 @@ def get_latest_complete_checkpoint(
         retrieval_plan = self._compute_retrieval_plan(checkpoint, available_objects_by_rank)
         # Broadcast the retrieval plan to all ranks.
         plan_container = [retrieval_plan]
-        dist.broadcast_object_list(plan_container, src=planner_rank)
+        self._broadcast_object_list_func(plan_container, src=planner_rank)
         retrieval_plan = plan_container[0]

         if retrieval_plan is None:
@@ -451,7 +468,7 @@ def _compute_retrieval_plan(

         objects_needed_by_local_rank_0.update(self._get_extra_needed_objects(checkpoint, available_objects_by_rank))

-        world_size = dist.get_world_size()
+        world_size = self._world_size_getter()
         num_nodes = get_num_of_nodes()
         ranks_per_node = world_size // num_nodes

@@ -507,8 +524,8 @@ def get_candidate_checkpoints(

         # Scan locally only on the first rank of each node
         base_path = Path(checkpoint_base_container.data)
-        rank = dist.get_rank()
-        local_rank = dist.get_node_local_rank()
+        rank = self._global_rank_getter()
+        local_rank = self._local_rank_getter()

         local_candidate_ckpt_ids = []

@@ -532,8 +549,8 @@ def get_candidate_checkpoints(
         else:
             _LOGGER.debug("Rank %s: Base path '%s' is not a directory or does not exist.", rank, base_path)

-        all_checkpoint_container_path_lists = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(all_checkpoint_container_path_lists, local_candidate_ckpt_ids)
+        all_checkpoint_container_path_lists = [None for _ in range(self._world_size_getter())]
+        self._all_gather_object_func(all_checkpoint_container_path_lists, local_candidate_ckpt_ids)
         _LOGGER.debug(
             "Rank %s: Gathered checkpoint container paths from all ranks: '%s'",
             rank,
@@ -589,8 +606,8 @@ def get_checkpoint_objects_by_rank(

         local_objects.extend(self._get_extra_local_objects(container_path))

-        all_objects_by_rank_paths = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(all_objects_by_rank_paths, local_objects)
+        all_objects_by_rank_paths = [None for _ in range(self._world_size_getter())]
+        self._all_gather_object_func(all_objects_by_rank_paths, local_objects)

         result = {}
         object_locations = defaultdict(list)
@@ -620,7 +637,7 @@ def retrieve_checkpoint(
             If empty for this rank, no retrieval is needed.
         """

-        rank = dist.get_rank()
+        rank = self._global_rank_getter()
         all_success = True

         # Only proceed with retrieval if we have items to retrieve
@@ -656,8 +673,8 @@ def retrieve_checkpoint(

         # Gather success status from all ranks
         _LOGGER.debug("Gathering success status from all ranks")
-        all_success_list = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(all_success_list, all_success)
+        all_success_list = [None for _ in range(self._world_size_getter())]
+        self._all_gather_object_func(all_success_list, all_success)
         _LOGGER.debug("All success list: '%s'", all_success_list)
         return all(all_success_list)
```
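To illustrate why the injection helps, here is a hedged sketch (not part of the commit) of a single-process harness that exercises the loader with no process group at all; the two manager objects are assumed to be in scope, and the fakes only mimic the calling conventions the loader relies on above:

```python
# Hypothetical single-process fakes with the same calling conventions as the
# torch.distributed collectives used by the loader.
def fake_all_gather_object(object_list, obj, group=None):
    # With world_size == 1, the gathered result is just this rank's own object.
    object_list[0] = obj


def fake_broadcast_object_list(object_list, src=0, group=None, device=None):
    # Single rank: the source's objects are already in object_list; nothing to do.
    pass


loader = DefaultMLFlashpointCheckpointLoader(
    checkpoint_object_manager,  # assumed in scope
    replication_manager,        # assumed in scope
    global_rank_getter=lambda: 0,
    local_rank_getter=lambda: 0,
    broadcast_object_list_func=fake_broadcast_object_list,
    all_gather_object_func=fake_all_gather_object,
    world_size_getter=lambda: 1,
)
```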

tests/adapter/nemo/test_nemo_checkpoint_loader.py

Lines changed: 21 additions & 11 deletions

```diff
@@ -25,22 +25,32 @@

 class TestNeMoCheckpointLoaderContext:
     @pytest.fixture
-    def loader(self):
+    def _setup_mocks(self, mocker):
+        self.mock_global_rank = MagicMock(return_value=0)
+        self.mock_local_rank = MagicMock(return_value=0)
+        self.mock_world_size = MagicMock(return_value=1)
+        self.mock_all_gather = MagicMock()
+        self.mock_broadcast = MagicMock()
+
+    @pytest.fixture
+    def loader(self, mocker, _setup_mocks):
         ckpt_manager = CheckpointObjectManager()
         repl_manager = MagicMock(spec=ReplicationManager)
         return NeMoMLFlashpointCheckpointLoader(
-            checkpoint_object_manager=ckpt_manager, replication_manager=repl_manager, recover_context=True
+            checkpoint_object_manager=ckpt_manager,
+            replication_manager=repl_manager,
+            global_rank_getter=self.mock_global_rank,
+            local_rank_getter=self.mock_local_rank,
+            broadcast_object_list_func=self.mock_broadcast,
+            all_gather_object_func=self.mock_all_gather,
+            world_size_getter=self.mock_world_size,
+            recover_context=True,
         )

     def test_get_checkpoint_objects_by_rank_finds_context(self, loader, mocker):
         """Test that get_checkpoint_objects_by_rank finds files in context/ dir when recover_context=True."""
-        mocker.patch("torch.distributed.get_world_size", return_value=1)
-        mocker.patch(
-            "torch.distributed.all_gather_object",
-            side_effect=lambda obj_list, local_obj: obj_list.__setitem__(0, local_obj),
-        )
-        # Mock get_node_local_rank to avoid external dependency issues if called
-        mocker.patch("torch.distributed.get_node_local_rank", return_value=0)
+        self.mock_world_size.return_value = 1
+        self.mock_all_gather.side_effect = lambda obj_list, local_obj: obj_list.__setitem__(0, local_obj)

         container_path = "/tmp/ckpt/step-1"
         container_id = CheckpointContainerId(container_path)
@@ -109,9 +119,9 @@ def test_compute_retrieval_plan_includes_context_optimized(self, loader, mocker)
         mock_metadata.storage_data = {}
         mocker.patch.object(loader, "read_metadata", return_value=mock_metadata)

-        mocker.patch("torch.distributed.get_world_size", return_value=4)
+        self.mock_world_size.return_value = 4
         mocker.patch("ml_flashpoint.core.checkpoint_loader.get_num_of_nodes", return_value=2)
-        mocker.patch("torch.distributed.get_rank", return_value=0)
+        self.mock_global_rank.return_value = 0

         ctx_file = str(Path(checkpoint.data) / "context" / "file1.txt")
         nested_ctx_file = str(Path(checkpoint.data) / "context" / "subdir" / "file3.txt")
```
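Because _setup_mocks stores the MagicMock callables on self, each test can retune return_value or side_effect per scenario, as the two rewritten tests do, and can also assert on collective usage directly. A hypothetical test fragment (not part of this commit; the method call and container fixture are assumptions) sketches that payoff:

```python
def test_candidate_scan_gathers_once(self, loader):
    # Hypothetical: route the fake all-gather's output, drive a scan, then
    # inspect the recorded collective call via the injected mock.
    self.mock_all_gather.side_effect = lambda out, obj: out.__setitem__(0, obj)
    loader.get_candidate_checkpoints(checkpoint_base_container)  # container assumed in scope
    self.mock_all_gather.assert_called_once()
```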
