Skip to content

Commit 3bfb113

Browse files
authored
Merge branch 'main' into abstract-distributed-apis-checkpoint-loader
2 parents bd621e7 + 7075ce4 commit 3bfb113

File tree

5 files changed

+57
-2
lines changed

5 files changed

+57
-2
lines changed

.github/workflows/build-and-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ jobs:
155155
echo ${{ github.event.number }} > pr_number.txt
156156
157157
- name: Archive coverage reports
158-
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # ratchet:actions/upload-artifact@v4
158+
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
159159
if: always()
160160
with:
161161
name: coverage-reports

.github/workflows/post-coverage-comment.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# This workflow is kept separate because it requires write permissions, which GitHub blocks
16+
# on forks for security reasons. Thus, in order to be able to post comments with
17+
# code coverage details on PRs, we run a separate workflow in the context of the
18+
# base repository after any PR build workflow completes, relying on files uploaded
19+
# by the PR's build.
1520
name: Post Coverage Comment
1621

1722
on:

src/ml_flashpoint/adapter/nemo/wrapper_util.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,14 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
197197
+ f"'{checkpoint_io.__class__.__name__}'."
198198
)
199199

200+
# Use 'spawn' instead of 'fork' for the multiprocessing context.
201+
# By default, 'fork' causes the background SyncManager process to inherit
202+
# the parent's CUDA context. If the main training process is forcefully
203+
# killed (e.g., via SIGKILL during NVRX in-job restarts), the orphaned
204+
# manager process keeps the GPU memory locked, leading to CUDA Out-Of-Memory
205+
# (OOM) errors upon restart. 'spawn' launches a clean interpreter without
206+
# the inherited CUDA state, allowing the GPU memory to be freed instantly.
207+
ctx = torch_mp.get_context("spawn")
200208
save_strategy = MLFlashpointMegatronAsyncSaveStrategy(
201209
storage_writer=MemoryStorageWriter(
202210
checkpoint_saver=DefaultMLFlashpointCheckpointSaver(
@@ -208,7 +216,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
208216
initial_buffer_size_bytes=initial_write_buffer_size_bytes,
209217
use_optimized_save=use_optimized_save,
210218
),
211-
mp_manager=torch_mp.Manager(),
219+
mp_manager=ctx.Manager(),
212220
thread_count=write_thread_count,
213221
)
214222
)

src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def __init__(
9797
handling the actual checkpoint saving logic.
9898
mp_manager: A `torch.multiprocessing.Manager` instance for managing
9999
shared state across processes, particularly for write results and events.
100+
It is highly recommended to create this manager using a 'spawn'
101+
multiprocessing context to avoid inheriting the parent's CUDA context,
102+
which prevents CUDA OOM errors during failure recovery.
100103
thread_count: Optional. The number of threads to use for writing checkpoint data.
101104
Defaults to 1. If a value less than 1 is provided, it will be reset to 1,
102105
and a warning will be logged.

tests/adapter/nemo/test_wrapper_util.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,45 @@ def test_write_thread_count_forwarding(
794794
_, kwargs = spy_memory_storage_writer_init.call_args
795795
assert kwargs["thread_count"] == expected_thread_count
796796

797+
def test_spawn_context_used_for_mp_manager(self, mocker, mock_ckpt_obj_manager, mock_replication_manager):
798+
"""Tests that torch_mp.get_context('spawn').Manager() is correctly instantiated and passed."""
799+
# Given
800+
trainer = mocker.MagicMock(spec=nl_trainer.Trainer)
801+
trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)]
802+
trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy)
803+
original_checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO)
804+
trainer.strategy.checkpoint_io = original_checkpoint_io
805+
base_container = "/test_base_container"
806+
807+
mock_get_context = mocker.patch("ml_flashpoint.adapter.nemo.wrapper_util.torch_mp.get_context")
808+
809+
mock_ctx = mock_get_context.return_value # The mocked context object
810+
mock_manager_instance = mock_ctx.Manager.return_value # The mocked manager instance
811+
812+
spy_memory_storage_writer_init = mocker.spy(MemoryStorageWriter, "__init__")
813+
814+
# When
815+
wrap_trainer_checkpoint_io_with_mlflashpoint(
816+
trainer,
817+
base_container,
818+
mock_ckpt_obj_manager,
819+
mock_replication_manager,
820+
async_save=True,
821+
checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader),
822+
)
823+
824+
# Then
825+
# Verify get_context was called explicitly with 'spawn'
826+
mock_get_context.assert_called_once_with("spawn")
827+
828+
# Verify Manager() was called on the correct spawn context
829+
mock_ctx.Manager.assert_called_once()
830+
831+
# Verify the exact Manager instance was passed to MemoryStorageWriter
832+
spy_memory_storage_writer_init.assert_called_once()
833+
_, kwargs = spy_memory_storage_writer_init.call_args
834+
assert kwargs["mp_manager"] is mock_manager_instance
835+
797836
@pytest.mark.parametrize("always_save_context, expected_value", [(True, True), (False, False)])
798837
def test_always_save_context_forwarding(
799838
self, mocker, mock_ckpt_obj_manager, mock_replication_manager, always_save_context, expected_value

0 commit comments

Comments
 (0)