@@ -55,28 +55,42 @@ def mock_listdir(path):
5555 return ["file1.txt" , "file2.txt" ]
5656 return []
5757
def mock_walk(path):
    """Stand-in for ``os.walk`` that simulates a nested ``context`` directory.

    When ``path`` is ``<container_path>/context`` this yields
    ``(root, dirs, files)`` tuples describing the tree:

        context/
            file1.txt
            file2.txt
            subdir/
                file3.txt

    For any other path it yields nothing.
    """
    context_root = str(Path(container_path) / "context")
    if str(path) != context_root:
        # Bare return: this is a generator, so a returned value would never
        # be seen by code iterating over it; ending early yields no entries.
        return
    yield (str(path), ["subdir"], ["file1.txt", "file2.txt"])
    yield (str(Path(path) / "subdir"), [], ["file3.txt"])
70+
def mock_isdir(path):
    """Replacement for ``Path.is_dir`` used by this test.

    Returns True only for the container directory itself and for its
    ``context`` subdirectory; every other path is reported as not a
    directory.
    """
    known_dirs = (
        container_path,
        str(Path(container_path) / "context"),
    )
    return str(path) in known_dirs
6477
78+ mocker .patch ("os.walk" , side_effect = mock_walk )
6579 mocker .patch ("os.listdir" , side_effect = mock_listdir )
6680 mocker .patch ("pathlib.Path.is_dir" , new = mock_isdir )
6781
6882 result = loader .get_checkpoint_objects_by_rank (container_id )
6983
7084 assert 0 in result
7185 objs = result [0 ]
72- # Should contain: .../context/file1.txt, .../context/file2.txt, .../context, .../other_file
73- # Note: 'context' is from top-level listdir. 'context/file*.txt' is from recursive check.
7486
7587 paths = [str (o .data ) for o in objs ]
7688 expected_context_file1 = str (Path (container_path ) / "context" / "file1.txt" )
7789 expected_context_file2 = str (Path (container_path ) / "context" / "file2.txt" )
90+ expected_nested_file3 = str (Path (container_path ) / "context" / "subdir" / "file3.txt" )
7891 assert expected_context_file1 in paths
7992 assert expected_context_file2 in paths
93+ assert expected_nested_file3 in paths
8094
8195 def test_compute_retrieval_plan_includes_context_optimized (self , loader , mocker ):
8296 """
@@ -100,15 +114,26 @@ def test_compute_retrieval_plan_includes_context_optimized(self, loader, mocker)
100114 mocker .patch ("torch.distributed.get_rank" , return_value = 0 )
101115
102116 ctx_file = str (Path (checkpoint .data ) / "context" / "file1.txt" )
117+ nested_ctx_file = str (Path (checkpoint .data ) / "context" / "subdir" / "file3.txt" )
103118 common_pt = str (Path (checkpoint .data ) / "common.pt" )
104119 metadata_file = str (Path (checkpoint .data ) / ".metadata" )
105120
106121 # Available objects:
107- # Node 0 (Rank 0,1) has everything (Context + Common + Metadata)
122+ # Node 0 (Rank 0,1) has everything (Context + Nested + Common + Metadata)
108123 # Node 1 (Rank 2,3) has nothing
109124 available_objects = {
110- 0 : [CheckpointObjectId (ctx_file ), CheckpointObjectId (common_pt ), CheckpointObjectId (metadata_file )],
111- 1 : [CheckpointObjectId (ctx_file ), CheckpointObjectId (common_pt ), CheckpointObjectId (metadata_file )],
125+ 0 : [
126+ CheckpointObjectId (ctx_file ),
127+ CheckpointObjectId (nested_ctx_file ),
128+ CheckpointObjectId (common_pt ),
129+ CheckpointObjectId (metadata_file ),
130+ ],
131+ 1 : [
132+ CheckpointObjectId (ctx_file ),
133+ CheckpointObjectId (nested_ctx_file ),
134+ CheckpointObjectId (common_pt ),
135+ CheckpointObjectId (metadata_file ),
136+ ],
112137 2 : [],
113138 3 : [],
114139 }
@@ -129,5 +154,9 @@ def test_compute_retrieval_plan_includes_context_optimized(self, loader, mocker)
129154 assert common_pt in retrieved_objs_2
130155 assert metadata_file in retrieved_objs_2
131156
157+ # Verify nested file
158+ nested_ctx_file = str (Path (checkpoint .data ) / "context" / "subdir" / "file3.txt" )
159+ assert nested_ctx_file in retrieved_objs_2
160+
132161 # Rank 3: Local rank 1 on Node 1. Shared FS with Rank 2. Should NOT retrieve.
133162 assert 3 not in plan or not plan [3 ]
0 commit comments