@@ -82,7 +82,7 @@ def error_run(weights: list[tuple[str, torch.Tensor]]):
8282 try :
8383 trigger_error (socket_paths )
8484 except RuntimeError as e :
85- assert str (e ) == "Failed to update weights due to remote errors "
85+ assert str (e ) == "Some workers failed to update weights "
8686
8787
8888def checker_proc (rank : int , device_uuid : str , named_tensors : dict [str , torch .Tensor ], queue : Queue ):
@@ -96,7 +96,7 @@ def check(names_to_check: dict[str, bool], weights: list[tuple[str, torch.Tensor
9696 for name , weight in weights :
9797 if name not in named_tensors :
9898 continue
99- assert (weight == named_tensors [name ]).all ()
99+ assert (weight == named_tensors [name ]).all (), f"Tensor { name } does not match!"
100100 names_to_check [name ] = True
101101
102102 def check_weights (names_to_check : dict [str , bool ], socket_paths : list [tuple [str , str ]]):
@@ -163,6 +163,61 @@ def run(
163163 assert proc .exitcode == 0
164164
165165
def run_with_files(
    checker_func: callable,
):
    """Exercise ParameterServer.update with a checkpoint sourced from three tiers.

    One third of the generated test tensors is saved to /dev/shm (tmpfs) as a
    .safetensors file, one third to /tmp (disk) as a .safetensors file, and the
    remaining third is registered directly from memory. A checker subprocess
    (``checker_func``) receives the full expected tensor dict and validates the
    broadcast weights it pulls from ``queue``.

    Args:
        checker_func: target for the spawned checker process; called as
            ``checker_func(rank, device_uuid, named_tensors, queue)``.
    """
    rank = int(os.getenv("RANK"))
    ctx = get_context("spawn")
    queue = ctx.Queue()
    ps = ParameterServer(auto_pg=True)
    _device_uuid = _get_physical_gpu_id(ps.device_manager, rank)
    named_tensors = dict(gen_test_tensors(rank))

    # Split the tensors into three non-overlapping groups:
    #   1/3 saved to /dev/shm/ as a .safetensors file
    #   1/3 saved to /tmp (disk) as a .safetensors file
    #   1/3 kept in memory and passed via named_tensors
    # (The previous slices overlapped, so some tensors were registered
    # from two sources at once.)
    import safetensors.torch

    files = []
    dev_shm_dir = "/dev/shm/checkpoint_engine_tests"  # noqa: S108
    disk_dir = "/tmp/checkpoint_engine_tests"  # noqa: S108
    os.makedirs(dev_shm_dir, exist_ok=True)
    os.makedirs(disk_dir, exist_ok=True)
    tensors_items = list(named_tensors.items())
    third = len(tensors_items) // 3
    tensors_in_dev_shm = dict(tensors_items[:third])
    tensors_in_disk = dict(tensors_items[third : 2 * third])
    tensors_in_memory = dict(tensors_items[2 * third :])

    disk_files = [
        os.path.join(disk_dir, f"rank{_rank}_checkpoint.safetensors")
        for _rank in range(get_world_size())
    ]
    safetensors.torch.save_file(tensors_in_disk, disk_files[rank])
    time.sleep(1)
    files.append(disk_files[rank])

    # NOTE(review): the original built this list with the *current* rank in
    # every entry; use the loop variable so each entry is distinct, matching
    # disk_files above. Only index [rank] is consumed, so behavior is the same.
    dev_shm_files = [
        os.path.join(dev_shm_dir, f"rank{_rank}_checkpoint.safetensors")
        for _rank in range(get_world_size())
    ]
    safetensors.torch.save_file(tensors_in_dev_shm, dev_shm_files[rank])
    time.sleep(1)
    files.append(dev_shm_files[rank])

    checkpoint_name = "test_with_files"
    proc = ctx.Process(target=checker_func, args=(rank, _device_uuid, named_tensors, queue))
    proc.start()
    ps.register_checkpoint(checkpoint_name, named_tensors=tensors_in_memory, files=files)
    ps.gather_metas(checkpoint_name)
    ps.update(checkpoint_name, queue.put, ranks=[])
    # sleep 3s to wait until the process group is destroyed
    time.sleep(3)
    ps.unregister_checkpoint(checkpoint_name)
    queue.put(None)
    proc.join()
    assert proc.exitcode == 0
220+
166221@pytest .mark .gpu
167222@pytest .mark .parametrize (
168223 "test_name,rank_list" ,
@@ -211,6 +266,37 @@ def test_update(test_name: str, rank_list: list[list[int]] | None):
211266 assert result .returncode == 0
212267
213268
@pytest.mark.gpu
def test_update_with_files(test_name: str = "test_with_files"):
    """Launch the file-backed update test under torchrun across all local GPUs.

    Spawns one worker per visible device and asserts the torchrun job exits
    cleanly. Requires at least 2 GPUs.
    """
    world_size = device_manager.device_module.device_count()
    assert world_size >= 2, "This test requires at least 2 GPUs."

    master_addr = "localhost"
    master_port = 25400
    launch_cmd = [
        "torchrun",
        "--nproc_per_node",
        str(world_size),
        "--master_addr",
        master_addr,
        "--master_port",
        str(master_port),
        __file__,
        test_name,
        "[]",
    ]

    # Run from the repository root (two levels up from this test file).
    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    completed = subprocess.run(  # noqa: S603
        launch_cmd,
        capture_output=False,
        text=True,
        cwd=repo_root,
        shell=False,
        check=False,
    )

    assert completed.returncode == 0
299+
214300if __name__ == "__main__" :
215301 run_with_pytest = "PYTEST_CURRENT_TEST" in os .environ
216302 if not run_with_pytest :
@@ -230,5 +316,7 @@ def test_update(test_name: str, rank_list: list[list[int]] | None):
230316 expected_exception = RuntimeError ,
231317 exception_msg = "Failed to update weights due to remote errors" ,
232318 )
319+ elif test_type == "test_with_files" :
320+ run_with_files (checker_proc )
233321 else :
234322 raise ValueError (f"Unknown TEST_TYPE: { test_type } " )
0 commit comments