@@ -189,7 +189,7 @@ def test_async_save_initialization_calls_success(
189189
190190 mock_memory_storage_writer_cls .assert_called_once_with (
191191 checkpoint_saver = checkpoint_saver ,
192- mp_manager = storage_writer ._mp_manager ,
192+ mp_manager = storage_writer ._main_process_torchmp_manager ,
193193 thread_count = storage_writer ._thread_count ,
194194 )
195195 mock_new_storage_writer_instance .reset .assert_called_once_with (checkpoint_id .data )
@@ -229,7 +229,7 @@ def test_async_save_reinitializes_storage_writer_with_thread_count(
229229 # Then
230230 mock_memory_storage_writer_cls .assert_called_once_with (
231231 checkpoint_saver = checkpoint_saver ,
232- mp_manager = storage_writer ._mp_manager ,
232+ mp_manager = storage_writer ._main_process_torchmp_manager ,
233233 thread_count = expected_thread_count ,
234234 )
235235
@@ -275,7 +275,9 @@ def test_async_save_generate_plan_call_success(self, mocker, async_save_setup, s
275275 assert kwargs ["state_dict" ] == pyt_state_dict
276276 assert actual_storage_writer_used is not None
277277 assert isinstance (actual_storage_writer_used , MemoryStorageWriter )
278- assert actual_storage_writer_used ._mp_manager is storage_writer ._mp_manager
278+ assert (
279+ actual_storage_writer_used ._main_process_torchmp_manager is storage_writer ._main_process_torchmp_manager
280+ )
279281 assert kwargs ["planner" ] is mock_planner
280282 assert "world_dist_wrapper" in kwargs
281283 assert kwargs ["world_dist_wrapper" ].use_dist is False
@@ -372,8 +374,8 @@ def test_async_save_finalize_fns_calls(
372374 "ml_flashpoint.adapter.megatron.save_strategies.MemoryStorageWriter"
373375 )
374376 mock_storage_writer_instance = mock_memory_storage_writer_cls .return_value
375- # We need to set _mp_manager on the mock because the test asserts on it later
376- mock_storage_writer_instance ._mp_manager = storage_writer ._mp_manager
377+ # We need to set _main_process_torchmp_manager on the mock because the test asserts on it later
378+ mock_storage_writer_instance ._main_process_torchmp_manager = storage_writer ._main_process_torchmp_manager
377379 mock_storage_writer_instance .stage_write_data_buckets .return_value = dummy_write_buckets
378380
379381 expected_kwarg_keys = {"checkpoint_id" , "storage_writer" , "global_metadata" , "world_dist_wrapper" }
@@ -405,7 +407,9 @@ def test_async_save_finalize_fns_calls(
405407 assert kwargs ["checkpoint_id" ] == checkpoint_id
406408 assert actual_storage_writer_used is not None
407409 assert actual_storage_writer_used is mock_storage_writer_instance
408- assert actual_storage_writer_used ._mp_manager is storage_writer ._mp_manager
410+ assert (
411+ actual_storage_writer_used ._main_process_torchmp_manager is storage_writer ._main_process_torchmp_manager
412+ )
409413 assert kwargs ["global_metadata" ] == dummy_metadata
410414 assert kwargs ["world_dist_wrapper" ].use_dist is False
411415
@@ -433,3 +437,41 @@ def test_finalize_fns_failure(
433437
434438 # Then
435439 finalize_checkpoint_spy .assert_not_called ()
440+
441+ @pytest .mark .parametrize (
442+ "is_dist_initialized, dist_rank, expected_rank" ,
443+ [
444+ (True , 5 , 5 ),
445+ (False , 0 , - 1 ),
446+ ],
447+ )
448+ def test_async_save_rank_determination (
449+ self ,
450+ mocker ,
451+ async_save_setup ,
452+ is_dist_initialized ,
453+ dist_rank ,
454+ expected_rank ,
455+ ):
456+ """Tests that the rank passed to async_fn is correct based on dist initialization."""
457+ # Given
458+ strategy , checkpoint_id , sharded_state_dict , _ = async_save_setup
459+
460+ # Mock torch.distributed
461+ mocker .patch ("torch.distributed.is_initialized" , return_value = is_dist_initialized )
462+ if is_dist_initialized :
463+ mocker .patch ("torch.distributed.get_rank" , return_value = dist_rank )
464+
465+ # Mock dependencies to ensure success path
466+ mock_statedictsaver = mocker .patch ("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver" )
467+ mock_statedictsaver .generate_plan .return_value = (
468+ mocker .MagicMock (),
469+ mocker .MagicMock (),
470+ mocker .MagicMock (),
471+ )
472+
473+ # When
474+ actual_async_request = strategy .async_save (sharded_state_dict , checkpoint_id .data )
475+
476+ # Then
477+ assert actual_async_request .async_fn_kwargs ["rank" ] == expected_rank
0 commit comments