Skip to content

Commit 7075ce4

Browse files
authored
fix(adapter/nemo): use spawn context for mp_manager to prevent OOM on NVRX restart (#59)
1 parent 7bb914f commit 7075ce4

File tree

3 files changed

+51
-1
lines changed

3 files changed

+51
-1
lines changed

src/ml_flashpoint/adapter/nemo/wrapper_util.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,14 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
197197
+ f"'{checkpoint_io.__class__.__name__}'."
198198
)
199199

200+
# Use 'spawn' instead of 'fork' for the multiprocessing context.
201+
# By default, 'fork' causes the background SyncManager process to inherit
202+
# the parent's CUDA context. If the main training process is forcefully
203+
# killed (e.g., via SIGKILL during NVRX in-job restarts), the orphaned
204+
# manager process keeps the GPU memory locked, leading to CUDA Out-Of-Memory
205+
# (OOM) errors upon restart. 'spawn' launches a clean interpreter without
206+
# the inherited CUDA state, allowing the GPU memory to be freed instantly.
207+
ctx = torch_mp.get_context("spawn")
200208
save_strategy = MLFlashpointMegatronAsyncSaveStrategy(
201209
storage_writer=MemoryStorageWriter(
202210
checkpoint_saver=DefaultMLFlashpointCheckpointSaver(
@@ -208,7 +216,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
208216
initial_buffer_size_bytes=initial_write_buffer_size_bytes,
209217
use_optimized_save=use_optimized_save,
210218
),
211-
mp_manager=torch_mp.Manager(),
219+
mp_manager=ctx.Manager(),
212220
thread_count=write_thread_count,
213221
)
214222
)

src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def __init__(
9797
handling the actual checkpoint saving logic.
9898
mp_manager: A `torch.multiprocessing.Manager` instance for managing
9999
shared state across processes, particularly for write results and events.
100+
It is highly recommended to create this manager using a 'spawn'
101+
multiprocessing context to avoid inheriting the parent's CUDA context,
102+
which prevents CUDA OOM errors during failure recoveries.
100103
thread_count: Optional. The number of threads to use for writing checkpoint data.
101104
Defaults to 1. If a value less than 1 is provided, it will be reset to 1,
102105
and a warning will be logged.

tests/adapter/nemo/test_wrapper_util.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,45 @@ def test_write_thread_count_forwarding(
794794
_, kwargs = spy_memory_storage_writer_init.call_args
795795
assert kwargs["thread_count"] == expected_thread_count
796796

797+
def test_spawn_context_used_for_mp_manager(self, mocker, mock_ckpt_obj_manager, mock_replication_manager):
798+
"""Tests that torch_mp.get_context('spawn').Manager() is correctly instantiated and passed."""
799+
# Given
800+
trainer = mocker.MagicMock(spec=nl_trainer.Trainer)
801+
trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)]
802+
trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy)
803+
original_checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO)
804+
trainer.strategy.checkpoint_io = original_checkpoint_io
805+
base_container = "/test_base_container"
806+
807+
mock_get_context = mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.torch_mp.get_context")
808+
809+
mock_ctx = mock_get_context.return_value # The mocked context object
810+
mock_manager_instance = mock_ctx.Manager.return_value # The mocked manager instance
811+
812+
spy_memory_storage_writer_init = mocker.spy(MemoryStorageWriter, "__init__")
813+
814+
# When
815+
wrap_trainer_checkpoint_io_with_mlflashpoint(
816+
trainer,
817+
base_container,
818+
mock_ckpt_obj_manager,
819+
mock_replication_manager,
820+
async_save=True,
821+
checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader),
822+
)
823+
824+
# Then
825+
# Verify get_context was called explicitly with 'spawn'
826+
mock_get_context.assert_called_once_with("spawn")
827+
828+
# Verify Manager() was called on the correct spawn context
829+
mock_ctx.Manager.assert_called_once()
830+
831+
# Verify the exact Manager instance was passed to MemoryStorageWriter
832+
spy_memory_storage_writer_init.assert_called_once()
833+
_, kwargs = spy_memory_storage_writer_init.call_args
834+
assert kwargs["mp_manager"] is mock_manager_instance
835+
797836
@pytest.mark.parametrize("always_save_context, expected_value", [(True, True), (False, False)])
798837
def test_always_save_context_forwarding(
799838
self, mocker, mock_ckpt_obj_manager, mock_replication_manager, always_save_context, expected_value

0 commit comments

Comments (0)