Skip to content

Commit 069c6e4

Browse files
committed
Recursively go through context/ dir.
1 parent 6bf5a3b commit 069c6e4

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

src/ml_flashpoint/adapter/nemo/nemo_checkpoint_loader.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,9 @@ def _get_extra_local_objects(self, container_path: Path) -> List[CheckpointObjec
5454
if self._recover_context:
5555
context_path = container_path / "context"
5656
if context_path.is_dir():
57-
for entry in os.listdir(context_path):
58-
local_objects.append(CheckpointObjectId(str(context_path / entry)))
57+
for root, _, files in os.walk(context_path):
58+
for file in files:
59+
local_objects.append(CheckpointObjectId(str(Path(root) / file)))
5960
return local_objects
6061

6162
@override
@@ -72,6 +73,10 @@ def _get_extra_needed_objects(
7273
context_path = Path(checkpoint.data) / "context"
7374
for objs in available_objects_by_rank.values():
7475
for obj in objs:
75-
if Path(str(obj.data)).parent == context_path:
76-
extra_needed.add(str(obj.data))
76+
try:
77+
if Path(str(obj.data)).is_relative_to(context_path):
78+
extra_needed.add(str(obj.data))
79+
except ValueError:
80+
# Path.is_relative_to raises ValueError if it's not relative to the path
81+
pass
7782
return extra_needed

tests/adapter/nemo/test_nemo_checkpoint_loader.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,28 +55,42 @@ def mock_listdir(path):
5555
return ["file1.txt", "file2.txt"]
5656
return []
5757

58+
def mock_walk(path):
59+
if str(path) == str(Path(container_path) / "context"):
60+
# Simulation of:
61+
# context/
62+
# file1.txt
63+
# file2.txt
64+
# subdir/
65+
# file3.txt
66+
# Yields: (root, dirs, files)
67+
yield (str(path), ["subdir"], ["file1.txt", "file2.txt"])
68+
yield (str(Path(path) / "subdir"), [], ["file3.txt"])
69+
return []
70+
5871
def mock_isdir(path):
5972
if str(path) == container_path:
6073
return True
6174
if str(path) == str(Path(container_path) / "context"):
6275
return True
6376
return False
6477

78+
mocker.patch("os.walk", side_effect=mock_walk)
6579
mocker.patch("os.listdir", side_effect=mock_listdir)
6680
mocker.patch("pathlib.Path.is_dir", new=mock_isdir)
6781

6882
result = loader.get_checkpoint_objects_by_rank(container_id)
6983

7084
assert 0 in result
7185
objs = result[0]
72-
# Should contain: .../context/file1.txt, .../context/file2.txt, .../context, .../other_file
73-
# Note: 'context' is from top-level listdir. 'context/file*.txt' is from recursive check.
7486

7587
paths = [str(o.data) for o in objs]
7688
expected_context_file1 = str(Path(container_path) / "context" / "file1.txt")
7789
expected_context_file2 = str(Path(container_path) / "context" / "file2.txt")
90+
expected_nested_file3 = str(Path(container_path) / "context" / "subdir" / "file3.txt")
7891
assert expected_context_file1 in paths
7992
assert expected_context_file2 in paths
93+
assert expected_nested_file3 in paths
8094

8195
def test_compute_retrieval_plan_includes_context_optimized(self, loader, mocker):
8296
"""
@@ -100,15 +114,26 @@ def test_compute_retrieval_plan_includes_context_optimized(self, loader, mocker)
100114
mocker.patch("torch.distributed.get_rank", return_value=0)
101115

102116
ctx_file = str(Path(checkpoint.data) / "context" / "file1.txt")
117+
nested_ctx_file = str(Path(checkpoint.data) / "context" / "subdir" / "file3.txt")
103118
common_pt = str(Path(checkpoint.data) / "common.pt")
104119
metadata_file = str(Path(checkpoint.data) / ".metadata")
105120

106121
# Available objects:
107-
# Node 0 (Rank 0,1) has everything (Context + Common + Metadata)
122+
# Node 0 (Rank 0,1) has everything (Context + Nested + Common + Metadata)
108123
# Node 1 (Rank 2,3) has nothing
109124
available_objects = {
110-
0: [CheckpointObjectId(ctx_file), CheckpointObjectId(common_pt), CheckpointObjectId(metadata_file)],
111-
1: [CheckpointObjectId(ctx_file), CheckpointObjectId(common_pt), CheckpointObjectId(metadata_file)],
125+
0: [
126+
CheckpointObjectId(ctx_file),
127+
CheckpointObjectId(nested_ctx_file),
128+
CheckpointObjectId(common_pt),
129+
CheckpointObjectId(metadata_file),
130+
],
131+
1: [
132+
CheckpointObjectId(ctx_file),
133+
CheckpointObjectId(nested_ctx_file),
134+
CheckpointObjectId(common_pt),
135+
CheckpointObjectId(metadata_file),
136+
],
112137
2: [],
113138
3: [],
114139
}
@@ -129,5 +154,9 @@ def test_compute_retrieval_plan_includes_context_optimized(self, loader, mocker)
129154
assert common_pt in retrieved_objs_2
130155
assert metadata_file in retrieved_objs_2
131156

157+
# Verify nested file
158+
nested_ctx_file = str(Path(checkpoint.data) / "context" / "subdir" / "file3.txt")
159+
assert nested_ctx_file in retrieved_objs_2
160+
132161
# Rank 3: Local rank 1 on Node 1. Shared FS with Rank 2. Should NOT retrieve.
133162
assert 3 not in plan or not plan[3]

0 commit comments

Comments
 (0)