Skip to content

Commit 8b9c0a3

Browse files
Leahlijuan, kkkapu, and g-husam
authored
perf(adapter/megatron): Change the MLF AsyncCallsQueue to persistent (#42)
- Enable pickling `replication_manager` while keeping `transfer_service`.
- Enable pickling `memory_storage_writer`.
- Inject rank and step into `AsyncRequest`.

Part of #40

- [ ] Tests pass
- [ ] Appropriate changes to documentation are included in the PR

After the changes, the average of the per-step maximum for `write_data` is 3.3280 s for Llama 8B on 2 A3-mega machines.

---------

Co-authored-by: kkkapu <110692213+kkkapu@users.noreply.github.com>
Co-authored-by: g-husam <husameldawi@google.com>
1 parent d15e7bb commit 8b9c0a3

File tree

10 files changed

+325
-36
lines changed

10 files changed

+325
-36
lines changed

src/ml_flashpoint/adapter/megatron/save_strategies.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pathlib import Path
2020
from typing import Union
2121

22+
import torch
2223
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
2324
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncRequest
2425
from megatron.core.dist_checkpointing.strategies.base import AsyncSaveShardedStrategy
@@ -32,7 +33,7 @@
3233

3334
from ml_flashpoint.adapter.pytorch import custom_state_dict_saver as statedictsaver
3435
from ml_flashpoint.adapter.pytorch.memory_storage_writer import MemoryStorageWriter
35-
from ml_flashpoint.core import utils
36+
from ml_flashpoint.core import mlf_logging, utils
3637
from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId
3738
from ml_flashpoint.core.checkpoint_saver import MLFlashpointCheckpointSaver, ObjectWriteBucket
3839
from ml_flashpoint.core.mlf_logging import get_logger
@@ -41,6 +42,26 @@
4142
_LOGGER = get_logger(__name__)
4243

4344

45+
def _save_checkpoint(
46+
staged_buckets: list[ObjectWriteBucket],
47+
checkpoint_id: CheckpointContainerId,
48+
storage_writer: MemoryStorageWriter,
49+
rank: int,
50+
step: int,
51+
):
52+
"""
53+
This function is the 'async_fn' run in Megatron's :class:`AsyncRequest`.
54+
"""
55+
56+
mlf_logging.setup_worker_logging(rank, step)
57+
statedictsaver.write_data(
58+
checkpoint_id=checkpoint_id,
59+
storage_writer=storage_writer,
60+
staged_write_buckets=staged_buckets,
61+
replicate_after_write=False,
62+
)
63+
64+
4465
def default_backend_format_name() -> str:
4566
return "ml_flashpoint"
4667

@@ -105,7 +126,7 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
105126
# 1b. Re-initialize the StorageWriter to use a new instance per save to avoid hangs from shared state.
106127
self._storage_writer = MemoryStorageWriter(
107128
checkpoint_saver=self._checkpoint_saver,
108-
mp_manager=self._storage_writer._mp_manager,
129+
mp_manager=self._storage_writer._main_process_torchmp_manager,
109130
thread_count=self._storage_writer._thread_count,
110131
)
111132
# 1c. Reset the StorageWriter for this checkpoint version.
@@ -156,17 +177,6 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
156177
with open(os.path.join(checkpoint_dir, "metadata.json"), "w") as f:
157178
json.dump(metadata, f)
158179

159-
def _save_checkpoint(staged_buckets: list[ObjectWriteBucket]):
160-
"""
161-
This function is the 'async_fn' run in Megatron's :class:`AsyncRequest`.
162-
"""
163-
statedictsaver.write_data(
164-
checkpoint_id=checkpoint_id,
165-
storage_writer=self._storage_writer,
166-
staged_write_buckets=staged_buckets,
167-
replicate_after_write=False,
168-
)
169-
170180
finalize_fns = [
171181
# Replicate written objects
172182
partial(
@@ -188,9 +198,18 @@ def _save_checkpoint(staged_buckets: list[ObjectWriteBucket]):
188198
),
189199
]
190200

201+
current_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else -1
202+
current_step = mlf_logging.get_current_step()
203+
191204
return AsyncRequest(
192205
async_fn=_save_checkpoint,
193206
async_fn_args=(),
194-
async_fn_kwargs={"staged_buckets": staged_write_buckets},
207+
async_fn_kwargs={
208+
"staged_buckets": staged_write_buckets,
209+
"checkpoint_id": checkpoint_id,
210+
"storage_writer": self._storage_writer,
211+
"rank": current_rank,
212+
"step": current_step,
213+
},
195214
finalize_fns=finalize_fns,
196215
)

src/ml_flashpoint/adapter/nemo/checkpoint_io.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO):
323323
raise ValueError("Incompatible wrapped checkpoint_io type: %s", type(checkpoint_io))
324324

325325
super().__init__(checkpoint_io)
326-
self._mlf_async_calls_queue = AsyncCallsQueue()
326+
self._mlf_async_calls_queue = AsyncCallsQueue(persistent=True)
327327
self._alt_async_calls_queue = AsyncCallsQueue()
328328

329329
@property
@@ -420,3 +420,15 @@ def teardown(self) -> None:
420420
):
421421
# Can't do finalization now because some ranks might be lost
422422
_LOGGER.warning("Some async checkpoint saves might be not finalized properly.")
423+
424+
if hasattr(self._mlf_async_calls_queue, "close"):
425+
self._mlf_async_calls_queue.close()
426+
# Monkeypatch persistent caller's close method to prevent double-close error at exit
427+
# which happens if __del__ is called after process group destruction.
428+
# We access the caller directly if possible as AsyncCallsQueue might store it as 'persistent_caller'.
429+
caller = getattr(self._mlf_async_calls_queue, "persistent_caller", None)
430+
if caller and hasattr(caller, "close"):
431+
# We already closed the queue (and hopefully the caller), so we prevent future closes.
432+
# Specifically, PersistentAsyncCaller.__del__ calls close() which calls torch.distributed.get_rank(),
433+
# causing a crash if the process group is already destroyed.
434+
caller.close = lambda: None

src/ml_flashpoint/adapter/nemo/wrapper_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
202202
checkpoint_saver=DefaultMLFlashpointCheckpointSaver(
203203
global_rank_getter=torch.distributed.get_rank,
204204
local_rank_getter=torch.distributed.get_node_local_rank,
205-
global_barrier_func=lambda: torch.distributed.barrier(),
205+
global_barrier_func=torch.distributed.barrier,
206206
ckpt_obj_manager=ckpt_obj_manager,
207207
replication_manager=replication_manager,
208208
initial_buffer_size_bytes=initial_write_buffer_size_bytes,

src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,24 @@ def __init__(
109109
_LOGGER.warning("thread_count must be >= 1, but was %d. Setting to 1.", thread_count)
110110
thread_count = 1
111111
self._thread_count = thread_count
112-
self._mp_manager = mp_manager
112+
# _main_process_torchmp_manager should only be used in the main process, not in the spawned processes.
113+
# This is because mp_manager is not picklable.
114+
self._main_process_torchmp_manager = mp_manager
113115
self._write_events_per_checkpoint_id: dict[CheckpointContainerId, torch_mp.Event] = mp_manager.dict()
114116
self._write_results_per_checkpoint_id: dict[CheckpointContainerId, list[WriteResult]] = mp_manager.dict()
115117

118+
def __getstate__(self):
119+
"""Custom pickling to exclude unpicklable mp_manager."""
120+
state = self.__dict__.copy()
121+
if "_main_process_torchmp_manager" in state:
122+
del state["_main_process_torchmp_manager"]
123+
return state
124+
125+
def __setstate__(self, state):
126+
"""Custom unpickling to restore state and set mp_manager to None."""
127+
self.__dict__.update(state)
128+
self._main_process_torchmp_manager = None
129+
116130
def _check_checkpoint_id(self) -> None:
117131
if self._current_checkpoint_id is None:
118132
raise ValueError("MemoryStorageWriter has not been reset. Call reset() before using this method.")
@@ -177,7 +191,7 @@ def prepare_write_data_buckets(
177191
) -> list[ObjectWriteBucket]:
178192
# Create a new, unset Event for this specific checkpoint save
179193
if checkpoint_id not in self._write_events_per_checkpoint_id:
180-
self._write_events_per_checkpoint_id[checkpoint_id] = self._mp_manager.Event()
194+
self._write_events_per_checkpoint_id[checkpoint_id] = self._main_process_torchmp_manager.Event()
181195

182196
write_buckets = self.checkpoint_saver.prepare_write_data(
183197
checkpoint_id, plan.items, planner, plan.storage_data.prefix, bucket_count=self._thread_count

src/ml_flashpoint/core/checkpoint_saver.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,21 @@ def __init__(
320320
self._initial_buffer_size_bytes = initial_buffer_size_bytes
321321
self._use_optimized_save = use_optimized_save
322322

323+
def __getstate__(self):
324+
"""Custom pickling to exclude _replication_manager."""
325+
state = self.__dict__.copy()
326+
# Exclude _replication_manager from the pickled state as it is not needed in workers
327+
# and may be unpickleable or expensive to transfer.
328+
if "_replication_manager" in state:
329+
del state["_replication_manager"]
330+
return state
331+
332+
def __setstate__(self, state):
333+
"""Custom unpickling to restore state and set _replication_manager to None."""
334+
self.__dict__.update(state)
335+
# Restore _replication_manager as None in the worker process
336+
self._replication_manager = None
337+
323338
@override
324339
@log_execution_time(logger=_LOGGER, name="initialize_checkpoint")
325340
def initialize_checkpoint(self, checkpoint_id: CheckpointContainerId) -> None:
@@ -506,6 +521,11 @@ def write_data(
506521

507522
@log_execution_time(logger=_LOGGER, name="async_replicate_object")
508523
def async_replicate_object(self, object_id: CheckpointObjectId) -> list[concurrent.futures.Future]:
524+
if self._replication_manager is None:
525+
# This can happen in worker processes where we don't pickle the manager.
526+
# If this is called, it means replicate_after_write=True was passed erroneously or
527+
# the strategy is trying to replicate in a worker where it shouldn't.
528+
raise RuntimeError("ReplicationManager is not available (None). Cannot replicate object.")
509529
object_buffer_io = self._chkpt_obj_manager.get_buffer(object_id)
510530
return self._replication_manager.async_replicate(object_buffer_io)
511531

src/ml_flashpoint/core/mlf_logging.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
# -1 is the default sentinel (invalid) value.
3131
_TRAINING_STEP = multiprocessing.Value("i", _MISSING_NONNEG_NUMERIC_VAL)
3232

33+
_STATIC_RANK = _MISSING_NONNEG_NUMERIC_VAL
34+
3335

3436
def update_training_step(new_val: int):
3537
"""Updates the global training step value used in logs.
@@ -42,6 +44,23 @@ def update_training_step(new_val: int):
4244
_TRAINING_STEP.value = new_val
4345

4446

47+
def get_current_step() -> int:
48+
"""Returns the current training step."""
49+
return _TRAINING_STEP.value
50+
51+
52+
def setup_worker_logging(rank: int, step: int):
53+
"""Sets up logging context for a worker process.
54+
55+
Args:
56+
rank: The rank to log.
57+
step: The step to log.
58+
"""
59+
global _STATIC_RANK
60+
_STATIC_RANK = rank
61+
update_training_step(step)
62+
63+
4564
class TrainingContextFormatter(logging.Formatter):
4665
"""A logging formatter that includes useful contextual information in the log records."""
4766

@@ -55,7 +74,12 @@ def format(self, record):
5574
Returns:
5675
The formatted log record as a string.
5776
"""
58-
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else _MISSING_NONNEG_NUMERIC_VAL
77+
if _STATIC_RANK != _MISSING_NONNEG_NUMERIC_VAL:
78+
rank = _STATIC_RANK
79+
elif torch.distributed.is_initialized():
80+
rank = torch.distributed.get_rank()
81+
else:
82+
rank = _MISSING_NONNEG_NUMERIC_VAL
5983
record.rank = rank
6084
step_val = _TRAINING_STEP.value
6185
record.curr_step = step_val

tests/adapter/megatron/test_save_strategies.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def test_async_save_initialization_calls_success(
189189

190190
mock_memory_storage_writer_cls.assert_called_once_with(
191191
checkpoint_saver=checkpoint_saver,
192-
mp_manager=storage_writer._mp_manager,
192+
mp_manager=storage_writer._main_process_torchmp_manager,
193193
thread_count=storage_writer._thread_count,
194194
)
195195
mock_new_storage_writer_instance.reset.assert_called_once_with(checkpoint_id.data)
@@ -229,7 +229,7 @@ def test_async_save_reinitializes_storage_writer_with_thread_count(
229229
# Then
230230
mock_memory_storage_writer_cls.assert_called_once_with(
231231
checkpoint_saver=checkpoint_saver,
232-
mp_manager=storage_writer._mp_manager,
232+
mp_manager=storage_writer._main_process_torchmp_manager,
233233
thread_count=expected_thread_count,
234234
)
235235

@@ -275,7 +275,9 @@ def test_async_save_generate_plan_call_success(self, mocker, async_save_setup, s
275275
assert kwargs["state_dict"] == pyt_state_dict
276276
assert actual_storage_writer_used is not None
277277
assert isinstance(actual_storage_writer_used, MemoryStorageWriter)
278-
assert actual_storage_writer_used._mp_manager is storage_writer._mp_manager
278+
assert (
279+
actual_storage_writer_used._main_process_torchmp_manager is storage_writer._main_process_torchmp_manager
280+
)
279281
assert kwargs["planner"] is mock_planner
280282
assert "world_dist_wrapper" in kwargs
281283
assert kwargs["world_dist_wrapper"].use_dist is False
@@ -372,8 +374,8 @@ def test_async_save_finalize_fns_calls(
372374
"ml_flashpoint.adapter.megatron.save_strategies.MemoryStorageWriter"
373375
)
374376
mock_storage_writer_instance = mock_memory_storage_writer_cls.return_value
375-
# We need to set _mp_manager on the mock because the test asserts on it later
376-
mock_storage_writer_instance._mp_manager = storage_writer._mp_manager
377+
# We need to set _main_process_torchmp_manager on the mock because the test asserts on it later
378+
mock_storage_writer_instance._main_process_torchmp_manager = storage_writer._main_process_torchmp_manager
377379
mock_storage_writer_instance.stage_write_data_buckets.return_value = dummy_write_buckets
378380

379381
expected_kwarg_keys = {"checkpoint_id", "storage_writer", "global_metadata", "world_dist_wrapper"}
@@ -405,7 +407,9 @@ def test_async_save_finalize_fns_calls(
405407
assert kwargs["checkpoint_id"] == checkpoint_id
406408
assert actual_storage_writer_used is not None
407409
assert actual_storage_writer_used is mock_storage_writer_instance
408-
assert actual_storage_writer_used._mp_manager is storage_writer._mp_manager
410+
assert (
411+
actual_storage_writer_used._main_process_torchmp_manager is storage_writer._main_process_torchmp_manager
412+
)
409413
assert kwargs["global_metadata"] == dummy_metadata
410414
assert kwargs["world_dist_wrapper"].use_dist is False
411415

@@ -433,3 +437,41 @@ def test_finalize_fns_failure(
433437

434438
# Then
435439
finalize_checkpoint_spy.assert_not_called()
440+
441+
@pytest.mark.parametrize(
442+
"is_dist_initialized, dist_rank, expected_rank",
443+
[
444+
(True, 5, 5),
445+
(False, 0, -1),
446+
],
447+
)
448+
def test_async_save_rank_determination(
449+
self,
450+
mocker,
451+
async_save_setup,
452+
is_dist_initialized,
453+
dist_rank,
454+
expected_rank,
455+
):
456+
"""Tests that the rank passed to async_fn is correct based on dist initialization."""
457+
# Given
458+
strategy, checkpoint_id, sharded_state_dict, _ = async_save_setup
459+
460+
# Mock torch.distributed
461+
mocker.patch("torch.distributed.is_initialized", return_value=is_dist_initialized)
462+
if is_dist_initialized:
463+
mocker.patch("torch.distributed.get_rank", return_value=dist_rank)
464+
465+
# Mock dependencies to ensure success path
466+
mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
467+
mock_statedictsaver.generate_plan.return_value = (
468+
mocker.MagicMock(),
469+
mocker.MagicMock(),
470+
mocker.MagicMock(),
471+
)
472+
473+
# When
474+
actual_async_request = strategy.async_save(sharded_state_dict, checkpoint_id.data)
475+
476+
# Then
477+
assert actual_async_request.async_fn_kwargs["rank"] == expected_rank

0 commit comments

Comments (0)