Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/ml_flashpoint/adapter/nemo/wrapper_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from typing import Union

import torch
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
FullyParallelLoadStrategyWrapper,
FullyParallelSaveStrategyWrapper,
)
from nemo import lightning as nl
from nemo.lightning.io.pl import MegatronCheckpointIO
from nemo.lightning.pytorch import strategies as nl_strategies
Expand Down Expand Up @@ -45,6 +49,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
write_thread_count: int = 1,
initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
use_optimized_save: bool = True,
use_fully_parallel_wrapper: bool = False,
) -> MLFlashpointAutoResume:
"""Wraps the trainer and creates an MLFlashpointAutoResume instance wrapping `default_auto_resume`.

Expand All @@ -62,6 +67,8 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1.
initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`.
use_fully_parallel_wrapper: Whether to wrap save/load strategies with `FullyParallel...Wrapper`.
Defaults to `False`.
Returns:
An MLFlashpointAutoResume instance configured for ML Flashpoint, wrapping `default_auto_resume`.
"""
Expand Down Expand Up @@ -90,6 +97,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
write_thread_count=write_thread_count,
initial_write_buffer_size_bytes=initial_write_buffer_size_bytes,
use_optimized_save=use_optimized_save,
use_fully_parallel_wrapper=use_fully_parallel_wrapper,
)

default_auto_resume_args = vars(default_auto_resume) if default_auto_resume else {}
Expand All @@ -111,6 +119,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
write_thread_count: int = 1,
initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
use_optimized_save: bool = True,
use_fully_parallel_wrapper: bool = False,
):
"""Wraps the trainer's checkpoint I/O with ML Flashpoint capabilities.

Expand Down Expand Up @@ -138,6 +147,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1.
initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`.
use_fully_parallel_wrapper: Whether to wrap save/load strategies with `FullyParallel...Wrapper`.
Defaults to `False`.

Returns:
None. The trainer's checkpoint_io is modified in-place.
Expand Down Expand Up @@ -217,6 +228,10 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
checkpoint_loader=checkpoint_loader,
)

if use_fully_parallel_wrapper:
save_strategy = FullyParallelSaveStrategyWrapper(save_strategy)
load_strategy = FullyParallelLoadStrategyWrapper(load_strategy)

ml_flashpoint_checkpoint_io = MLFlashpointCheckpointIO(
flashpoint_base_path=flashpoint_base_container,
alt_checkpoint_io=checkpoint_io,
Expand Down
1 change: 1 addition & 0 deletions src/ml_flashpoint/checkpoint_object_manager/buffer_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Union

from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject

from ml_flashpoint.core.mlf_logging import get_logger

from .buffer_metadata import METADATA_SIZE, BufferMetadataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
import shutil
from typing import Optional

from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject

from ml_flashpoint.checkpoint_object_manager.buffer_io import BufferIO
from ml_flashpoint.checkpoint_object_manager.buffer_metadata import METADATA_SIZE
from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject
from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId, CheckpointObjectId
from ml_flashpoint.core.mlf_logging import get_logger

Expand Down
75 changes: 75 additions & 0 deletions tests/adapter/nemo/test_wrapper_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
import dataclasses

import pytest
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
FullyParallelLoadStrategyWrapper,
FullyParallelSaveStrategyWrapper,
)
from nemo import lightning as nl
from nemo.lightning.io.pl import MegatronCheckpointIO
from nemo.lightning.pytorch import strategies as nl_strategies
Expand Down Expand Up @@ -103,6 +107,7 @@ def test_successful_wrap_and_resume_creation(self, mocker):
write_thread_count=1,
initial_write_buffer_size_bytes=DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
use_optimized_save=True,
use_fully_parallel_wrapper=False,
)

# 3. Result is correct type and has correct attributes
Expand Down Expand Up @@ -489,6 +494,76 @@ def test_successful_wrapping_no_async_wrapper(self, mocker, mock_ckpt_obj_manage
assert trainer.strategy.checkpoint_io.fallback_checkpoint_io is original_checkpoint_io
assert trainer.strategy.checkpoint_io.async_save is True

def test_fully_parallel_wrapper_enabled(self, mocker, mock_ckpt_obj_manager, mock_replication_manager):
    """Verify that enabling `use_fully_parallel_wrapper` wraps both the save and
    load strategies in their FullyParallel wrappers."""
    # Arrange: a mocked trainer whose MegatronStrategy carries a MegatronCheckpointIO.
    mock_trainer = mocker.MagicMock(spec=nl_trainer.Trainer)
    mock_trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)]
    mock_trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy)
    mock_trainer.strategy.checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO)

    # Act: wrap the trainer's checkpoint IO with the fully-parallel flag switched on.
    wrap_trainer_checkpoint_io_with_mlflashpoint(
        mock_trainer,
        "/test_base_container",
        mock_ckpt_obj_manager,
        mock_replication_manager,
        async_save=True,
        checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader),
        use_fully_parallel_wrapper=True,
    )

    # Assert: the installed IO is ML Flashpoint's, and both strategies were wrapped.
    installed_io = mock_trainer.strategy.checkpoint_io
    assert isinstance(installed_io, MLFlashpointCheckpointIO)
    assert isinstance(installed_io.save_strategy, FullyParallelSaveStrategyWrapper)
    assert isinstance(installed_io.load_strategy, FullyParallelLoadStrategyWrapper)

def test_fully_parallel_wrapper_disabled_by_default(self, mocker, mock_ckpt_obj_manager, mock_replication_manager):
    """Verify that with `use_fully_parallel_wrapper=False` (the default behavior)
    neither strategy is wrapped in a FullyParallel wrapper."""
    # Arrange: a mocked trainer whose MegatronStrategy carries a MegatronCheckpointIO.
    mock_trainer = mocker.MagicMock(spec=nl_trainer.Trainer)
    mock_trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)]
    mock_trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy)
    mock_trainer.strategy.checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO)

    # Act: wrap the trainer's checkpoint IO with the fully-parallel flag left off.
    wrap_trainer_checkpoint_io_with_mlflashpoint(
        mock_trainer,
        "/test_base_container",
        mock_ckpt_obj_manager,
        mock_replication_manager,
        async_save=True,
        checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader),
        use_fully_parallel_wrapper=False,
    )

    # Assert: the installed IO is ML Flashpoint's, and the strategies stayed unwrapped.
    installed_io = mock_trainer.strategy.checkpoint_io
    assert isinstance(installed_io, MLFlashpointCheckpointIO)
    assert not isinstance(installed_io.save_strategy, FullyParallelSaveStrategyWrapper)
    assert not isinstance(installed_io.load_strategy, FullyParallelLoadStrategyWrapper)

def test_successful_wrapping_with_async_wrapper(self, mocker, mock_ckpt_obj_manager, mock_replication_manager):
"""Tests successful wrapping when an async wrapper is present."""
# Given
Expand Down
2 changes: 1 addition & 1 deletion tests/checkpoint_object_manager/test_buffer_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
import tempfile

import pytest
from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject

from ml_flashpoint.checkpoint_object_manager.buffer_io import METADATA_SIZE, BufferIO
from ml_flashpoint.checkpoint_object_manager.buffer_metadata import BufferMetadataType
from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject
from ml_flashpoint.core.defaults import CheckpointFormat


Expand Down
Loading