diff --git a/src/ml_flashpoint/adapter/nemo/wrapper_util.py b/src/ml_flashpoint/adapter/nemo/wrapper_util.py index 1f33501..d315e2e 100644 --- a/src/ml_flashpoint/adapter/nemo/wrapper_util.py +++ b/src/ml_flashpoint/adapter/nemo/wrapper_util.py @@ -15,6 +15,10 @@ from typing import Union import torch +from megatron.core.dist_checkpointing.strategies.fully_parallel import ( + FullyParallelLoadStrategyWrapper, + FullyParallelSaveStrategyWrapper, +) from nemo import lightning as nl from nemo.lightning.io.pl import MegatronCheckpointIO from nemo.lightning.pytorch import strategies as nl_strategies @@ -45,6 +49,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint( write_thread_count: int = 1, initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES, use_optimized_save: bool = True, + use_fully_parallel_wrapper: bool = False, ) -> MLFlashpointAutoResume: """Wraps the trainer and creates an MLFlashpointAutoResume instance wrapping `default_auto_resume`. @@ -62,6 +67,8 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint( write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1. initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`. + use_fully_parallel_wrapper: Optional. Whether to wrap the save and load strategies with Megatron's + `FullyParallelSaveStrategyWrapper` and `FullyParallelLoadStrategyWrapper`. Defaults to `False`. Returns: An MLFlashpointAutoResume instance configured for ML Flashpoint, wrapping `default_auto_resume`. 
""" @@ -90,6 +97,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint( write_thread_count=write_thread_count, initial_write_buffer_size_bytes=initial_write_buffer_size_bytes, use_optimized_save=use_optimized_save, + use_fully_parallel_wrapper=use_fully_parallel_wrapper, ) default_auto_resume_args = vars(default_auto_resume) if default_auto_resume else {} @@ -111,6 +119,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint( write_thread_count: int = 1, initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES, use_optimized_save: bool = True, + use_fully_parallel_wrapper: bool = False, ): """Wraps the trainer's checkpoint I/O with ML Flashpoint capabilities. @@ -138,6 +147,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint( write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1. initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`. + use_fully_parallel_wrapper: Optional. Whether to wrap the save and load strategies with Megatron's + `FullyParallelSaveStrategyWrapper` and `FullyParallelLoadStrategyWrapper`. Defaults to `False`. Returns: None. The trainer's checkpoint_io is modified in-place. 
@@ -217,6 +228,10 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint( checkpoint_loader=checkpoint_loader, ) + if use_fully_parallel_wrapper: + save_strategy = FullyParallelSaveStrategyWrapper(save_strategy) + load_strategy = FullyParallelLoadStrategyWrapper(load_strategy) + ml_flashpoint_checkpoint_io = MLFlashpointCheckpointIO( flashpoint_base_path=flashpoint_base_container, alt_checkpoint_io=checkpoint_io, diff --git a/src/ml_flashpoint/checkpoint_object_manager/buffer_io.py b/src/ml_flashpoint/checkpoint_object_manager/buffer_io.py index 62ba13c..9fd1b27 100644 --- a/src/ml_flashpoint/checkpoint_object_manager/buffer_io.py +++ b/src/ml_flashpoint/checkpoint_object_manager/buffer_io.py @@ -16,6 +16,7 @@ from typing import Union from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject + from ml_flashpoint.core.mlf_logging import get_logger from .buffer_metadata import METADATA_SIZE, BufferMetadataType diff --git a/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py b/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py index 929bcb8..b38630e 100644 --- a/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py +++ b/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py @@ -16,9 +16,10 @@ import shutil from typing import Optional +from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject + from ml_flashpoint.checkpoint_object_manager.buffer_io import BufferIO from ml_flashpoint.checkpoint_object_manager.buffer_metadata import METADATA_SIZE -from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId, CheckpointObjectId from ml_flashpoint.core.mlf_logging import get_logger diff --git a/tests/adapter/nemo/test_wrapper_util.py b/tests/adapter/nemo/test_wrapper_util.py index 3f9a4c7..d6841e8 100644 --- 
a/tests/adapter/nemo/test_wrapper_util.py +++ b/tests/adapter/nemo/test_wrapper_util.py @@ -17,6 +17,10 @@ import dataclasses import pytest +from megatron.core.dist_checkpointing.strategies.fully_parallel import ( + FullyParallelLoadStrategyWrapper, + FullyParallelSaveStrategyWrapper, +) from nemo import lightning as nl from nemo.lightning.io.pl import MegatronCheckpointIO from nemo.lightning.pytorch import strategies as nl_strategies @@ -103,6 +107,7 @@ def test_successful_wrap_and_resume_creation(self, mocker): write_thread_count=1, initial_write_buffer_size_bytes=DEFAULT_INITIAL_BUFFER_SIZE_BYTES, use_optimized_save=True, + use_fully_parallel_wrapper=False, ) # 3. Result is correct type and has correct attributes @@ -489,6 +494,76 @@ def test_successful_wrapping_no_async_wrapper(self, mocker, mock_ckpt_obj_manage assert trainer.strategy.checkpoint_io.fallback_checkpoint_io is original_checkpoint_io assert trainer.strategy.checkpoint_io.async_save is True + def test_fully_parallel_wrapper_enabled(self, mocker, mock_ckpt_obj_manager, mock_replication_manager): + """Tests that FullyParallel wrappers are applied when flag=True.""" + + # Given + trainer = mocker.MagicMock(spec=nl_trainer.Trainer) + trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)] + trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy) + original_checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO) + trainer.strategy.checkpoint_io = original_checkpoint_io + base_container = "/test_base_container" + + # When + wrap_trainer_checkpoint_io_with_mlflashpoint( + trainer, + base_container, + mock_ckpt_obj_manager, + mock_replication_manager, + async_save=True, + checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader), + use_fully_parallel_wrapper=True, # 🔥 enable it + ) + + # Then + wrapped_io = trainer.strategy.checkpoint_io + assert isinstance(wrapped_io, MLFlashpointCheckpointIO) + + assert isinstance( + wrapped_io.save_strategy, + 
FullyParallelSaveStrategyWrapper, + ) + assert isinstance( + wrapped_io.load_strategy, + FullyParallelLoadStrategyWrapper, + ) + + def test_fully_parallel_wrapper_disabled_by_default(self, mocker, mock_ckpt_obj_manager, mock_replication_manager): + """Tests that FullyParallel wrappers are NOT applied when flag=False.""" + + # Given + trainer = mocker.MagicMock(spec=nl_trainer.Trainer) + trainer.callbacks = [mocker.MagicMock(spec=MLFlashpointCheckpointCallback)] + trainer.strategy = mocker.MagicMock(spec=nl_strategies.MegatronStrategy) + original_checkpoint_io = mocker.MagicMock(spec=MegatronCheckpointIO) + trainer.strategy.checkpoint_io = original_checkpoint_io + base_container = "/test_base_container" + + # When + wrap_trainer_checkpoint_io_with_mlflashpoint( + trainer, + base_container, + mock_ckpt_obj_manager, + mock_replication_manager, + async_save=True, + checkpoint_loader=mocker.MagicMock(spec=DefaultMLFlashpointCheckpointLoader), + use_fully_parallel_wrapper=False, # default behavior + ) + + # Then + wrapped_io = trainer.strategy.checkpoint_io + assert isinstance(wrapped_io, MLFlashpointCheckpointIO) + + assert not isinstance( + wrapped_io.save_strategy, + FullyParallelSaveStrategyWrapper, + ) + assert not isinstance( + wrapped_io.load_strategy, + FullyParallelLoadStrategyWrapper, + ) + def test_successful_wrapping_with_async_wrapper(self, mocker, mock_ckpt_obj_manager, mock_replication_manager): """Tests successful wrapping when an async wrapper is present.""" # Given diff --git a/tests/checkpoint_object_manager/test_buffer_io.py b/tests/checkpoint_object_manager/test_buffer_io.py index 6d54043..a301771 100644 --- a/tests/checkpoint_object_manager/test_buffer_io.py +++ b/tests/checkpoint_object_manager/test_buffer_io.py @@ -20,10 +20,10 @@ import tempfile import pytest +from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject from ml_flashpoint.checkpoint_object_manager.buffer_io import METADATA_SIZE, BufferIO from 
ml_flashpoint.checkpoint_object_manager.buffer_metadata import BufferMetadataType -from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject from ml_flashpoint.core.defaults import CheckpointFormat