@@ -794,6 +794,45 @@ def test_write_thread_count_forwarding(
794794 _ , kwargs = spy_memory_storage_writer_init .call_args
795795 assert kwargs ["thread_count" ] == expected_thread_count
796796
def test_spawn_context_used_for_mp_manager(self, mocker, mock_ckpt_obj_manager, mock_replication_manager):
    """Check that the manager handed to MemoryStorageWriter comes from the 'spawn' context.

    ``torch_mp.get_context`` is patched inside ``wrapper_util`` so no real
    multiprocessing manager is started; the test then asserts that the context
    was requested with ``"spawn"``, that ``Manager()`` was invoked on it, and
    that the resulting object is what ``MemoryStorageWriter.__init__`` received
    as its ``mp_manager`` keyword argument.
    """
    # Given: a trainer mock whose Megatron strategy carries a stock checkpoint IO
    original_checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO)
    strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy)
    strategy.checkpoint_io = original_checkpoint_io

    trainer = mocker.MagicMock(spec=nl_trainer.Trainer)
    trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)]
    trainer.strategy = strategy

    base_container = "/test_base_container"
    checkpoint_loader = mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader)

    # Replace the context factory as seen from wrapper_util, and keep handles on
    # the objects it will hand out so identity can be checked afterwards.
    get_context_patch = mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.torch_mp.get_context")
    spawn_context = get_context_patch.return_value
    expected_manager = spawn_context.Manager.return_value

    # Spy (rather than stub) the writer constructor so its kwargs can be inspected.
    writer_init_spy = mocker.spy(MemoryStorageWriter, "__init__")

    # When
    wrap_trainer_checkpoint_io_with_mlflashpoint(
        trainer,
        base_container,
        mock_ckpt_obj_manager,
        mock_replication_manager,
        async_save=True,
        checkpoint_loader=checkpoint_loader,
    )

    # Then: the context must have been requested exactly once, with 'spawn'
    get_context_patch.assert_called_once_with("spawn")

    # ... and the manager must have been created from that very context
    spawn_context.Manager.assert_called_once()

    # ... and that same manager object must reach MemoryStorageWriter
    writer_init_spy.assert_called_once()
    init_kwargs = writer_init_spy.call_args.kwargs
    assert init_kwargs["mp_manager"] is expected_manager
835+
797836 @pytest .mark .parametrize ("always_save_context, expected_value" , [(True , True ), (False , False )])
798837 def test_always_save_context_forwarding (
799838 self , mocker , mock_ckpt_obj_manager , mock_replication_manager , always_save_context , expected_value
0 commit comments