Skip to content

Commit 066f2a3

Browse files
Leahlijuang-husam
andauthored
refactor(core): Move context recovery logic strictly to the NeMo layer (#19)
- Extracts context recovery logic from DefaultMLFlashpointCheckpointLoader - Introduces _get_extra_local_objects and _get_extra_needed_objects to DefaultMLFlashpointCheckpointLoader - Updates NeMo wrapper to instantiate the new NeMoMLFlashpointCheckpointLoader --------- Co-authored-by: g-husam <husameldawi@google.com>
1 parent 14b9566 commit 066f2a3

File tree

6 files changed

+329
-148
lines changed

6 files changed

+329
-148
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from pathlib import Path
17+
from typing import List, Set
18+
19+
from typing_extensions import override
20+
21+
from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
22+
from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId, CheckpointObjectId
23+
from ml_flashpoint.core.checkpoint_loader import DefaultMLFlashpointCheckpointLoader
24+
from ml_flashpoint.replication.replication_manager import ReplicationManager
25+
26+
27+
class NeMoMLFlashpointCheckpointLoader(DefaultMLFlashpointCheckpointLoader):
    """
    NeMo-specific implementation of the MLFlashpointCheckpointLoader interface.

    Extends the default loader with optional recovery of the NeMo "context"
    directory: when enabled, files under ``<container>/context`` are advertised
    as locally available objects and are marked as needed on every node.
    """

    def __init__(
        self,
        checkpoint_object_manager: CheckpointObjectManager,
        replication_manager: ReplicationManager,
        recover_context: bool = False,
    ):
        """Initializes the NeMoMLFlashpointCheckpointLoader.

        Args:
            checkpoint_object_manager: The checkpoint object manager to use for
                reading data.
            replication_manager: The replication manager to use for retrieving
                missing checkpoint objects from peer nodes.
            recover_context: Whether to recover the context directory if missing.
        """
        super().__init__(checkpoint_object_manager, replication_manager)
        self._recover_context = recover_context

    @override
    def _get_extra_local_objects(self, container_path: Path) -> List[CheckpointObjectId]:
        """Returns all files under ``container_path / "context"`` as checkpoint objects.

        Called on every rank. Returns an empty list when context recovery is
        disabled or the context directory does not exist locally.

        Args:
            container_path: Filesystem path of the checkpoint container.

        Returns:
            A list of CheckpointObjectIds, one per file found (recursively)
            under the context directory.
        """
        local_objects: List[CheckpointObjectId] = []
        if self._recover_context:
            context_path = container_path / "context"
            if context_path.is_dir():
                # Walk recursively so nested context files (e.g. context/subdir/*)
                # are also advertised to peer nodes.
                for root, _, files in os.walk(context_path):
                    for file in files:
                        local_objects.append(CheckpointObjectId(str(Path(root) / file)))
        return local_objects

    @override
    def _get_extra_needed_objects(
        self,
        checkpoint: CheckpointContainerId,
        available_objects_by_rank: dict[int, List[CheckpointObjectId]],
    ) -> Set[str]:
        """Collects every context-directory object advertised by any rank.

        Args:
            checkpoint: The checkpoint container ID.
            available_objects_by_rank: Map of rank to available objects on that rank.

        Returns:
            Set of object paths (as strings) that live under the checkpoint's
            context directory; empty when context recovery is disabled.
        """
        extra_needed: Set[str] = set()
        if self._recover_context:
            # We assume that if a rank has the context dir, the content in the dir is complete.
            # We assume that these are the files needed by all the nodes.
            context_path = Path(checkpoint.data) / "context"
            for objs in available_objects_by_rank.values():
                for obj in objs:
                    # Path.is_relative_to() returns False for non-relative paths
                    # (it is relative_to() that raises ValueError), so no
                    # try/except is needed here.
                    if Path(str(obj.data)).is_relative_to(context_path):
                        extra_needed.add(str(obj.data))
        return extra_needed

src/ml_flashpoint/adapter/nemo/wrapper_util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from ml_flashpoint.adapter.nemo.auto_resume import MLFlashpointAutoResume
2828
from ml_flashpoint.adapter.nemo.checkpoint_callback import MLFlashpointCheckpointCallback
2929
from ml_flashpoint.adapter.nemo.checkpoint_io import MLFlashpointAsyncFinalizableCheckpointIO, MLFlashpointCheckpointIO
30+
from ml_flashpoint.adapter.nemo.nemo_checkpoint_loader import NeMoMLFlashpointCheckpointLoader
3031
from ml_flashpoint.adapter.pytorch.memory_storage_writer import MemoryStorageWriter
3132
from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
3233
from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId
@@ -72,7 +73,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
7273
replication_manager = ReplicationManager()
7374
replication_manager.initialize(checkpoint_object_manager=ckpt_obj_manager)
7475

75-
ckpt_loader = DefaultMLFlashpointCheckpointLoader(
76+
ckpt_loader = NeMoMLFlashpointCheckpointLoader(
7677
checkpoint_object_manager=ckpt_obj_manager,
7778
replication_manager=replication_manager,
7879
recover_context=always_save_context,

src/ml_flashpoint/core/checkpoint_loader.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import struct
2323
from collections import defaultdict
2424
from pathlib import Path
25-
from typing import IO, List, Optional, Tuple, TypeVar, cast
25+
from typing import IO, List, Optional, Set, Tuple, TypeVar, cast
2626

2727
import torch
2828
import torch.distributed as dist
@@ -128,7 +128,6 @@ def __init__(
128128
self,
129129
checkpoint_object_manager: CheckpointObjectManager,
130130
replication_manager: ReplicationManager,
131-
recover_context: bool = False,
132131
):
133132
"""Initializes the DefaultMLFlashpointCheckpointLoader.
134133
@@ -137,11 +136,9 @@ def __init__(
137136
reading data.
138137
replication_manager: The replication manager to use for retrieving
139138
missing checkpoint objects from peer nodes.
140-
recover_context: Whether to recover the context directory if missing.
141139
"""
142140
self._checkpoint_object_manager = checkpoint_object_manager
143141
self._replication_manager = replication_manager
144-
self._recover_context = recover_context
145142
# Cache for available objects: CheckpointContainerId -> dict[object_path, list[rank]]
146143
self._available_objects_cache: dict[CheckpointContainerId, dict[str, List[int]]] = {}
147144

@@ -417,6 +414,8 @@ def _compute_retrieval_plan(
417414
) -> Optional[dict[int, List[Tuple[int, str]]]]:
418415
"""Computes the retrieval plan.
419416
417+
The plan assumes an equal number of ranks (accelerator processes) on each node in the training cluster.
418+
420419
Args:
421420
checkpoint: The checkpoint container ID.
422421
available_objects_by_rank: Map of rank to available objects on that rank.
@@ -450,14 +449,7 @@ def _compute_retrieval_plan(
450449
str(CheckpointObjectId.from_container(checkpoint, default_metadata_object_name()))
451450
)
452451

453-
if self._recover_context:
454-
# We assume that if a rank has the context dir, the content in the dir is complete.
455-
# We assume that are the files needed by all the nodes.
456-
context_path = Path(checkpoint.data) / "context"
457-
for objs in available_objects_by_rank.values():
458-
for obj in objs:
459-
if Path(obj.data).parent == context_path:
460-
objects_needed_by_local_rank_0.add(str(obj.data))
452+
objects_needed_by_local_rank_0.update(self._get_extra_needed_objects(checkpoint, available_objects_by_rank))
461453

462454
world_size = dist.get_world_size()
463455
num_nodes = get_num_of_nodes()
@@ -581,25 +573,21 @@ def get_checkpoint_objects_by_rank(
581573
checkpoint_container_id: The ID of the checkpoint container to inspect.
582574
583575
Returns:
584-
A dictionary mapping each node's local rank 0 to a list of
585-
`CheckpointObjectId`s available on that node.
576+
A dictionary mapping each rank to a list of
577+
`CheckpointObjectId`s available on that rank.
586578
"""
587579
container_path = Path(checkpoint_container_id.data)
588-
local_objects = []
580+
local_objects: List[CheckpointObjectId] = []
589581
if not container_path.is_dir():
590582
_LOGGER.debug(
591583
"Checkpoint container path '%s' is not a directory. Returning empty list.",
592584
container_path,
593585
)
594586
else:
595587
for entry in os.listdir(container_path):
596-
local_objects.append(entry)
588+
local_objects.append(CheckpointObjectId.from_container(checkpoint_container_id, entry))
597589

598-
if self._recover_context:
599-
context_path = container_path / "context"
600-
if context_path.is_dir():
601-
for entry in os.listdir(context_path):
602-
local_objects.append(os.path.join("context", entry))
590+
local_objects.extend(self._get_extra_local_objects(container_path))
603591

604592
all_objects_by_rank_paths = [None for _ in range(dist.get_world_size())]
605593
dist.all_gather_object(all_objects_by_rank_paths, local_objects)
@@ -609,11 +597,8 @@ def get_checkpoint_objects_by_rank(
609597
if all_objects_by_rank_paths:
610598
for rank, objects in enumerate(all_objects_by_rank_paths):
611599
if objects:
612-
# Convert filenames to full paths and then to CheckpointObjectId
613-
full_paths = [str(container_path / obj) for obj in objects]
614-
checkpoint_objects = [CheckpointObjectId(p) for p in full_paths]
615-
result[rank] = checkpoint_objects
616-
for obj in checkpoint_objects:
600+
result[rank] = objects
601+
for obj in objects:
617602
object_locations[obj.data].append(rank)
618603
else:
619604
result[rank] = []
@@ -675,3 +660,37 @@ def retrieve_checkpoint(
675660
dist.all_gather_object(all_success_list, all_success)
676661
_LOGGER.debug("All success list: '%s'", all_success_list)
677662
return all(all_success_list)
663+
664+
def _get_extra_local_objects(self, container_path: Path) -> List[CheckpointObjectId]:
    """Hook for subclasses to provide extra local objects that are available on the current rank,
    which may be needed by other ranks. This would be called on every rank.

    This can be used when additional objects beyond the standard checkpoint data are needed,
    such as framework-specific context data.

    This should always be implemented alongside `_get_extra_needed_objects`.

    Args:
        container_path: Filesystem path of the checkpoint container to scan
            for extra objects.

    Returns:
        List of additional locally available objects.
    """
    # Default implementation: no extra objects beyond the standard checkpoint data.
    return []

def _get_extra_needed_objects(
    self,
    checkpoint: CheckpointContainerId,
    available_objects_by_rank: dict[int, List[CheckpointObjectId]],
) -> Set[str]:
    """Hook for subclasses to provide extra needed objects for each node.

    The objects returned by this method are considered necessary for the first rank
    (local rank 0) of every node.

    This can leverage `available_objects_by_rank` to determine the set of additional objects
    needed.

    This should always be implemented alongside `_get_extra_local_objects`.

    Args:
        checkpoint: The checkpoint container ID.
        available_objects_by_rank: Map of rank to the objects available on that rank.

    Returns:
        Set of extra needed objects on each node (specifically local rank 0).
    """
    # Default implementation: nothing extra is required on any node.
    return set()
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from pathlib import Path
16+
from unittest.mock import MagicMock
17+
18+
import pytest
19+
20+
from ml_flashpoint.adapter.nemo.nemo_checkpoint_loader import NeMoMLFlashpointCheckpointLoader
21+
from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
22+
from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId, CheckpointObjectId
23+
from ml_flashpoint.replication.replication_manager import ReplicationManager
24+
25+
26+
class TestNeMoCheckpointLoaderContext:
    """Tests for context-directory recovery in NeMoMLFlashpointCheckpointLoader."""

    @pytest.fixture
    def loader(self):
        # Loader under test, with context recovery enabled and a mocked
        # replication manager (no real peer communication).
        ckpt_manager = CheckpointObjectManager()
        repl_manager = MagicMock(spec=ReplicationManager)
        return NeMoMLFlashpointCheckpointLoader(
            checkpoint_object_manager=ckpt_manager, replication_manager=repl_manager, recover_context=True
        )

    def test_get_checkpoint_objects_by_rank_finds_context(self, loader, mocker):
        """Test that get_checkpoint_objects_by_rank finds files in context/ dir when recover_context=True."""
        # Single-rank world: all_gather_object just stores the local list at index 0.
        mocker.patch("torch.distributed.get_world_size", return_value=1)
        mocker.patch(
            "torch.distributed.all_gather_object",
            side_effect=lambda obj_list, local_obj: obj_list.__setitem__(0, local_obj),
        )
        # Mock get_node_local_rank to avoid external dependency issues if called
        mocker.patch("torch.distributed.get_node_local_rank", return_value=0)

        container_path = "/tmp/ckpt/step-1"
        container_id = CheckpointContainerId(container_path)

        # Mock fs
        # base dir has "context" directory
        # context dir has "file1.txt"
        def mock_listdir(path):
            if str(path) == container_path:
                return ["context", "other_file"]
            if str(path) == str(Path(container_path) / "context"):
                return ["file1.txt", "file2.txt"]
            return []

        def mock_walk(path):
            if str(path) == str(Path(container_path) / "context"):
                # Simulation of:
                # context/
                #   file1.txt
                #   file2.txt
                #   subdir/
                #     file3.txt
                # Yields: (root, dirs, files)
                yield (str(path), ["subdir"], ["file1.txt", "file2.txt"])
                yield (str(Path(path) / "subdir"), [], ["file3.txt"])
            return []

        def mock_isdir(path):
            # Patched in as Path.is_dir, so `path` receives the Path instance.
            if str(path) == container_path:
                return True
            if str(path) == str(Path(container_path) / "context"):
                return True
            return False

        mocker.patch("os.walk", side_effect=mock_walk)
        mocker.patch("os.listdir", side_effect=mock_listdir)
        mocker.patch("pathlib.Path.is_dir", new=mock_isdir)

        result = loader.get_checkpoint_objects_by_rank(container_id)

        assert 0 in result
        objs = result[0]

        # The context files (including the nested one) must be reported as
        # locally available objects with full paths.
        paths = [str(o.data) for o in objs]
        expected_context_file1 = str(Path(container_path) / "context" / "file1.txt")
        expected_context_file2 = str(Path(container_path) / "context" / "file2.txt")
        expected_nested_file3 = str(Path(container_path) / "context" / "subdir" / "file3.txt")
        assert expected_context_file1 in paths
        assert expected_context_file2 in paths
        assert expected_nested_file3 in paths

    def test_compute_retrieval_plan_includes_context_optimized(self, loader, mocker):
        """
        Test that _compute_retrieval_plan includes context files ONLY for local rank 0 on each node.
        Scenario:
        - World Size: 4
        - Nodes: 2 (Ranks 0,1 on Node 0; Ranks 2,3 on Node 1)
        - Rank 0 has context files.
        - Rank 2 needs context files (different node).
        - Rank 1, 3 do NOT need context retrieval (same node as 0, 2).
        """
        checkpoint = CheckpointContainerId("/tmp/ckpt/step-1")

        # Mock metadata read (empty storage data)
        mock_metadata = MagicMock()
        mock_metadata.storage_data = {}
        mocker.patch.object(loader, "read_metadata", return_value=mock_metadata)

        mocker.patch("torch.distributed.get_world_size", return_value=4)
        mocker.patch("ml_flashpoint.core.checkpoint_loader.get_num_of_nodes", return_value=2)
        mocker.patch("torch.distributed.get_rank", return_value=0)

        ctx_file = str(Path(checkpoint.data) / "context" / "file1.txt")
        nested_ctx_file = str(Path(checkpoint.data) / "context" / "subdir" / "file3.txt")
        common_pt = str(Path(checkpoint.data) / "common.pt")
        metadata_file = str(Path(checkpoint.data) / ".metadata")

        # Available objects:
        # Node 0 (Rank 0,1) has everything (Context + Nested + Common + Metadata)
        # Node 1 (Rank 2,3) has nothing
        available_objects = {
            0: [
                CheckpointObjectId(ctx_file),
                CheckpointObjectId(nested_ctx_file),
                CheckpointObjectId(common_pt),
                CheckpointObjectId(metadata_file),
            ],
            1: [
                CheckpointObjectId(ctx_file),
                CheckpointObjectId(nested_ctx_file),
                CheckpointObjectId(common_pt),
                CheckpointObjectId(metadata_file),
            ],
            2: [],
            3: [],
        }

        # Execute
        plan = loader._compute_retrieval_plan(checkpoint, available_objects)

        assert plan is not None

        # Node 0: Already has files, no retrieval needed (or plan[0] is empty)
        assert 0 not in plan or not plan[0]
        assert 1 not in plan or not plan[1]

        # Rank 2: Local rank 0 on Node 1. Needs Context + Common + Metadata.
        assert 2 in plan
        retrieved_objs_2 = [path for src, path in plan[2]]
        assert ctx_file in retrieved_objs_2
        assert common_pt in retrieved_objs_2
        assert metadata_file in retrieved_objs_2

        # Verify nested file
        # NOTE(review): nested_ctx_file is already computed above — this
        # reassignment is redundant (same value) and could be dropped.
        nested_ctx_file = str(Path(checkpoint.data) / "context" / "subdir" / "file3.txt")
        assert nested_ctx_file in retrieved_objs_2

        # Rank 3: Local rank 1 on Node 1. Shared FS with Rank 2. Should NOT retrieve.
        assert 3 not in plan or not plan[3]

0 commit comments

Comments
 (0)