Skip to content

Commit 6bf5a3b

Browse files
committed
Resolve comments.
1 parent f988d70 commit 6bf5a3b

File tree

6 files changed

+55
-26
lines changed

6 files changed

+55
-26
lines changed

src/ml_flashpoint/adapter/nemo/checkpoint_loader.py renamed to src/ml_flashpoint/adapter/nemo/nemo_checkpoint_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ def __init__(
4848
self._recover_context = recover_context
4949

5050
@override
51-
def _get_extra_local_objects(self, container_path: Path) -> List[str]:
51+
def _get_extra_local_objects(self, container_path: Path) -> List[CheckpointObjectId]:
5252
"""Returns extra local objects to include, specifically context files."""
5353
local_objects = []
5454
if self._recover_context:
5555
context_path = container_path / "context"
5656
if context_path.is_dir():
5757
for entry in os.listdir(context_path):
58-
local_objects.append(os.path.join("context", entry))
58+
local_objects.append(CheckpointObjectId(str(context_path / entry)))
5959
return local_objects
6060

6161
@override

src/ml_flashpoint/adapter/nemo/wrapper_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from ml_flashpoint.adapter.nemo.auto_resume import MLFlashpointAutoResume
2828
from ml_flashpoint.adapter.nemo.checkpoint_callback import MLFlashpointCheckpointCallback
2929
from ml_flashpoint.adapter.nemo.checkpoint_io import MLFlashpointAsyncFinalizableCheckpointIO, MLFlashpointCheckpointIO
30-
from ml_flashpoint.adapter.nemo.checkpoint_loader import NeMoMLFlashpointCheckpointLoader
30+
from ml_flashpoint.adapter.nemo.nemo_checkpoint_loader import NeMoMLFlashpointCheckpointLoader
3131
from ml_flashpoint.adapter.pytorch.memory_storage_writer import MemoryStorageWriter
3232
from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
3333
from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId

src/ml_flashpoint/core/checkpoint_loader.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,8 @@ def _compute_retrieval_plan(
379379
) -> Optional[dict[int, List[Tuple[int, str]]]]:
380380
"""Computes the retrieval plan.
381381
382+
The plan assumes every node in the training cluster runs the same number of ranks (accelerator processes).
383+
382384
Args:
383385
checkpoint: The checkpoint container ID.
384386
available_objects_by_rank: Map of rank to available objects on that rank.
@@ -540,15 +542,15 @@ def get_checkpoint_objects_by_rank(
540542
`CheckpointObjectId`s available on that node.
541543
"""
542544
container_path = Path(checkpoint_container_id.data)
543-
local_objects = []
545+
local_objects: List[CheckpointObjectId] = []
544546
if not container_path.is_dir():
545547
_LOGGER.debug(
546548
"Checkpoint container path '%s' is not a directory. Returning empty list.",
547549
container_path,
548550
)
549551
else:
550552
for entry in os.listdir(container_path):
551-
local_objects.append(entry)
553+
local_objects.append(CheckpointObjectId.from_container(checkpoint_container_id, entry))
552554

553555
local_objects.extend(self._get_extra_local_objects(container_path))
554556

@@ -560,11 +562,8 @@ def get_checkpoint_objects_by_rank(
560562
if all_objects_by_rank_paths:
561563
for rank, objects in enumerate(all_objects_by_rank_paths):
562564
if objects:
563-
# Convert filenames to full paths and then to CheckpointObjectId
564-
full_paths = [str(container_path / obj) for obj in objects]
565-
checkpoint_objects = [CheckpointObjectId(p) for p in full_paths]
566-
result[rank] = checkpoint_objects
567-
for obj in checkpoint_objects:
565+
result[rank] = objects
566+
for obj in objects:
568567
object_locations[obj.data].append(rank)
569568
else:
570569
result[rank] = []
@@ -628,13 +627,13 @@ def retrieve_checkpoint(
628627
return all(all_success_list)
629628

630629
def _get_extra_local_objects(self, container_path: Path) -> List[CheckpointObjectId]:
631-
"""Hook for subclasses to provide extra local objects that are available and relevant,
630+
"""Hook for subclasses to provide extra local objects that are available and relevant,
632631
which may be needed by other hosts.
633-
This can be used when additional objects beyond the standard checkpoint data are needed,
632+
This can be used when additional objects beyond the standard checkpoint data are needed,
634633
such as framework-specific context data.
635-
634+
636635
This should always be implemented alongside `_get_extra_needed_objects`.
637-
636+
638637
Returns:
639638
List of additional locally available objects.
640639
"""
@@ -646,11 +645,11 @@ def _get_extra_needed_objects(
646645
available_objects_by_rank: dict[int, List[CheckpointObjectId]],
647646
) -> Set[str]:
648647
"""Hook for subclasses to provide extra needed objects on any given host (each local rank 0).
649-
This can leverage `available_objects_by_rank` to determine the set of additional objects
648+
This can leverage `available_objects_by_rank` to determine the set of additional objects
650649
each host needs.
651-
650+
652651
This should always be implemented alongside `_get_extra_local_objects`.
653-
652+
654653
Returns:
655654
Set of extra needed objects on any given node.
656655
"""

tests/adapter/nemo/test_nemo_checkpoint_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import pytest
1919

20-
from ml_flashpoint.adapter.nemo.checkpoint_loader import NeMoMLFlashpointCheckpointLoader
20+
from ml_flashpoint.adapter.nemo.nemo_checkpoint_loader import NeMoMLFlashpointCheckpointLoader
2121
from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
2222
from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId, CheckpointObjectId
2323
from ml_flashpoint.replication.replication_manager import ReplicationManager

tests/adapter/nemo/test_wrapper_util.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,32 @@ def test_write_thread_count_forwarding(
256256
_, kwargs = spy_memory_storage_writer_init.call_args # Capture kwargs
257257
assert kwargs["thread_count"] == expected_thread_count
258258

259+
@pytest.mark.parametrize("always_save_context", [True, False])
260+
def test_loader_initialization_arguments(self, mocker, always_save_context):
261+
"""Tests that NeMoMLFlashpointCheckpointLoader is initialized with correct arguments."""
262+
# Given
263+
mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.ReplicationManager")
264+
mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.wrap_trainer_checkpoint_io_with_mlflashpoint")
265+
mock_loader_cls = mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.NeMoMLFlashpointCheckpointLoader")
266+
267+
trainer = mocker.MagicMock(spec=nl_trainer.Trainer)
268+
flashpoint_base_container = "/tmp/test_container"
269+
default_auto_resume = nl.AutoResume()
270+
271+
# When
272+
wrap_trainer_and_auto_resume_with_mlflashpoint(
273+
trainer,
274+
flashpoint_base_container,
275+
async_save=True,
276+
default_auto_resume=default_auto_resume,
277+
always_save_context=always_save_context,
278+
)
279+
280+
# Then
281+
mock_loader_cls.assert_called_once()
282+
_, kwargs = mock_loader_cls.call_args
283+
assert kwargs["recover_context"] == always_save_context
284+
259285

260286
class TestWrapTrainerCheckpointIOWithMLFlashpoint:
261287
"""Tests for the wrap_trainer_checkpoint_io_with_mlflashpoint function."""

tests/core/test_checkpoint_loader.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,8 +1725,9 @@ def test_get_checkpoint_objects_by_rank_optimization(self, loader, mocker, check
17251725

17261726
# Mock all_gather to return some data for rank 1
17271727
def side_effect_all_gather(out_list, in_obj):
1728-
out_list[0] = ["obj1", "obj2"] # Rank 0 data
1729-
out_list[1] = ["obj3"] # Rank 1 data (filename only)
1728+
out_list[0] = in_obj # Rank 0 data (passed in)
1729+
# Rank 1 data
1730+
out_list[1] = [CheckpointObjectId(str(ckpt_dir / "obj3"))]
17301731

17311732
mock_all_gather.side_effect = side_effect_all_gather
17321733

@@ -1735,11 +1736,13 @@ def side_effect_all_gather(out_list, in_obj):
17351736
# Call method
17361737
result = loader.get_checkpoint_objects_by_rank(container_id)
17371738

1738-
# Verify all_gather called with filenames only
1739+
# Verify all_gather called
17391740
mock_all_gather.assert_called_once()
17401741
args, _ = mock_all_gather.call_args
1741-
# args[1] is the input object (list of filenames for rank 0)
1742-
assert set(args[1]) == {"obj1", "obj2"}
1742+
# args[1] is the input object (list of CheckpointObjectIds for rank 0)
1743+
# Should be full paths now
1744+
expected_local = {str(ckpt_dir / "obj1"), str(ckpt_dir / "obj2")}
1745+
assert set(str(o.data) for o in args[1]) == expected_local
17431746

17441747
# Verify result has full paths
17451748
assert 0 in result
@@ -1913,9 +1916,10 @@ def test_get_checkpoint_objects_by_rank_sync(self, mocker, loader, tmp_path):
19131916
(container_path / "file1").touch()
19141917

19151918
def side_effect_all_gather(obj_list, local_obj):
1916-
obj_list[0] = ["file1"]
1917-
obj_list[1] = ["file1"]
1918-
obj_list[2] = ["file2"]
1919+
# Manually constructing full paths for the mock, mimicking what other ranks would send
1920+
obj_list[0] = [CheckpointObjectId(str(container_path / "file1"))]
1921+
obj_list[1] = [CheckpointObjectId(str(container_path / "file1"))]
1922+
obj_list[2] = [CheckpointObjectId(str(container_path / "file2"))]
19191923
obj_list[3] = []
19201924

19211925
mock_all_gather.side_effect = side_effect_all_gather

0 commit comments

Comments
 (0)