Commit f25fca6

Authored by ronaldwen07 and gemini-code-assist[bot]
refactor(adapter/megatron): extract local-aware megatron save into helper function (#43)
## Description

Extract the local-aware Megatron distributed save logic from `MLFlashpointCheckpointIO.save_checkpoint()` into a reusable helper function, `save_local_aware_megatron_checkpoint()`. This change lets users implementing Megatron checkpointing adopt the local-aware pattern without copying inline code from the implementation.

## Changes

- Created new `src/ml_flashpoint/adapter/megatron/save_utils.py` with the `save_local_aware_megatron_checkpoint()` helper function
- Refactored `MLFlashpointCheckpointIO.save_checkpoint()` to use the new helper
- Updated `docs/user-guide.md` to reference the helper function instead of pointing to implementation details
- Exported the helper from the `ml_flashpoint.adapter.megatron` module

## Type of Change

- [x] Refactoring
- [ ] Bug fix
- [ ] New feature
- [ ] Performance improvement
- [ ] Documentation update

## Testing

- [x] Code passes ruff linting
- [ ] Tests pass locally

Closes #29

---------

Co-authored-by: ronaldwen07 <ronaldwen07@users.noreply.github.com>
Co-authored-by: ron <138569343+ronaldwen07@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 5f49b38 commit f25fca6

File tree

5 files changed: +134 −61 lines changed

docs/user-guide.md — 14 additions & 2 deletions

````diff
@@ -133,6 +133,7 @@ from ml_flashpoint.replication.replication_manager import ReplicationManager
 
 # Megatron Checkpointing
 from megatron.core import dist_checkpointing as mcore_dist_checkpointing
+from ml_flashpoint.adapter.megatron.save_utils import save_local_aware_megatron_checkpoint
 ```
 
 #### Save Strategy
@@ -150,8 +151,19 @@ megatron_save_strategy = MLFlashpointMegatronAsyncSaveStrategy(
 )
 ```
 
-Because Megatron's `dist_checkpointing.save()` function writes "common" data only on global rank 0, which does not align with local checkpointing, you can orchestrate saves using the save strategy the same way it's done in [`MLFlashpointCheckpointIO.save_checkpoint()`](https://github.com/google/ml-flashpoint/blob/b9767583520106f59743b9e8050769523cfbef6e/src/ml_flashpoint/adapter/nemo/checkpoint_io.py#L137-L171) in the `ml_flashpoint.adapter.nemo` package.
-You'll notice that the logic there aims to mimic `dist_checkpointing.save`, but it saves common data on each node (via local rank 0) as opposed to solely on the coordinator node (global rank 0).
+Because Megatron's `dist_checkpointing.save()` function writes "common" data only on global rank 0, which does not align with local checkpointing, use the provided helper function `save_local_aware_megatron_checkpoint()` from the `ml_flashpoint.adapter.megatron.save_utils` module.
+
+This helper mimics `dist_checkpointing.save()`, but saves common data on each node (via local rank 0) rather than solely on the coordinator node (global rank 0).
+
+```python
+# In your save loop
+async_request = save_local_aware_megatron_checkpoint(
+    checkpoint=state_dict,
+    checkpoint_dir=str(curr_step_checkpoint_id),
+    save_strategy=megatron_save_strategy,
+    async_save=True,
+)
+```
 
 !!! note
 
````
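The difference between the two schemes can be sketched with a small, self-contained illustration. The `common_writer_ranks` helper below is hypothetical (not part of ml_flashpoint), and it assumes global ranks are assigned to nodes in contiguous blocks of `local_world_size`:

```python
# Hypothetical illustration (not part of ml_flashpoint): which global ranks
# write the "common" state file under each scheme, assuming ranks are laid
# out in contiguous blocks of `local_world_size` per node.

def common_writer_ranks(world_size: int, local_world_size: int, local_aware: bool) -> list[int]:
    if not local_aware:
        # Megatron's default dist_checkpointing.save(): global rank 0 only.
        return [0]
    # Local-aware scheme: local rank 0 on every node writes its own copy.
    return [r for r in range(world_size) if r % local_world_size == 0]

# Two nodes with 4 GPUs each:
print(common_writer_ranks(8, 4, local_aware=False))  # [0]
print(common_writer_ranks(8, 4, local_aware=True))   # [0, 4]
```

Under the default scheme, a node that cannot reach the coordinator has no local copy of the common state; the local-aware scheme trades a small amount of duplicated I/O for per-node recoverability.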
src/ml_flashpoint/adapter/megatron/__init__.py — 3 additions & 0 deletions

```diff
@@ -12,3 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ml_flashpoint.adapter.megatron.save_utils import (
+    save_local_aware_megatron_checkpoint as save_local_aware_megatron_checkpoint,
+)
```
src/ml_flashpoint/adapter/megatron/save_utils.py — 79 additions & 0 deletions (new file)

```python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path
from typing import Any, Optional, Union

import torch
from megatron.core.dist_checkpointing import state_dict_utils as mcore_state_dict_utils
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncRequest
from megatron.core.dist_checkpointing.strategies.common import COMMON_STATE_FNAME

from ml_flashpoint.core.mlf_logging import get_logger

_LOGGER = get_logger(__name__)


def save_local_aware_megatron_checkpoint(
    checkpoint: dict[str, Any],
    checkpoint_dir: Union[str, Path],
    save_strategy,
    async_save: bool = True,
) -> Optional[AsyncRequest]:
    """Saves a checkpoint with local-aware common state handling.

    This function mimics the CommonStrategy logic from Megatron's dist_checkpointing.save(),
    but with a key difference: it saves common data on each node (via local rank 0)
    rather than solely on the coordinator node (global rank 0).

    This is necessary for local checkpointing where each node needs its own copy
    of the common state for fast recovery.

    Args:
        checkpoint: The checkpoint dictionary to save.
        checkpoint_dir: The directory path to save the checkpoint to.
        save_strategy: The save strategy instance with async_save() and save() methods.
            Typically MLFlashpointMegatronAsyncSaveStrategy.
        async_save: Whether to save asynchronously. Defaults to True.

    Returns:
        An AsyncRequest if async_save is True and save succeeds, None otherwise.
        Returns None on save failure (exception is logged).
    """
    # Split common and sharded state
    sharded_state_dict, common_state_dict = mcore_state_dict_utils.save_preprocess(checkpoint)

    # Save common state on each node (local rank 0)
    if torch.distributed.get_node_local_rank() == 0:
        _LOGGER.debug("Saving common_state_dict...")
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(common_state_dict, os.path.join(checkpoint_dir, COMMON_STATE_FNAME))

    # Execute save strategy
    try:
        if async_save:
            return save_strategy.async_save(
                sharded_state_dict=sharded_state_dict,
                checkpoint_dir=checkpoint_dir,
            )
        else:
            save_strategy.save(
                sharded_state_dict=sharded_state_dict,
                checkpoint_dir=checkpoint_dir,
            )
            return None
    except Exception:
        _LOGGER.exception("Failed to save ML Flashpoint checkpoint. Skipping saving and continuing.")
        return None
```
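Stripped of the Megatron and torch dependencies, the control flow of this helper can be sketched with stand-ins. Everything below — `_StubStrategy`, the `sharded.` key prefix, and the json common file — is illustrative only; the real helper uses `save_preprocess()`, `torch.save()`, and `COMMON_STATE_FNAME`:

```python
import json
import os
import tempfile
from typing import Any, Optional


class _StubStrategy:
    """Records calls, standing in for MLFlashpointMegatronAsyncSaveStrategy."""

    def __init__(self) -> None:
        self.calls: list[tuple[str, str]] = []

    def async_save(self, sharded_state_dict: dict, checkpoint_dir: str) -> str:
        self.calls.append(("async", checkpoint_dir))
        return "async-request"  # stands in for an AsyncRequest

    def save(self, sharded_state_dict: dict, checkpoint_dir: str) -> None:
        self.calls.append(("sync", checkpoint_dir))


def save_local_aware_sketch(
    checkpoint: dict[str, Any],
    checkpoint_dir: str,
    save_strategy: _StubStrategy,
    local_rank: int,
    async_save: bool = True,
) -> Optional[str]:
    # Stand-in for save_preprocess(): treat "sharded." keys as per-rank data.
    sharded = {k: v for k, v in checkpoint.items() if k.startswith("sharded.")}
    common = {k: v for k, v in checkpoint.items() if not k.startswith("sharded.")}

    # The key difference from the default scheme: every node's local rank 0
    # writes the common state, not just global rank 0.
    if local_rank == 0:
        os.makedirs(checkpoint_dir, exist_ok=True)
        with open(os.path.join(checkpoint_dir, "common.json"), "w") as f:
            json.dump(common, f)

    if async_save:
        return save_strategy.async_save(sharded_state_dict=sharded, checkpoint_dir=checkpoint_dir)
    save_strategy.save(sharded_state_dict=sharded, checkpoint_dir=checkpoint_dir)
    return None


with tempfile.TemporaryDirectory() as tmp:
    strategy = _StubStrategy()
    ckpt = {"iteration": 100, "sharded.weights": [1, 2, 3]}
    request = save_local_aware_sketch(ckpt, tmp, strategy, local_rank=0)
    assert request == "async-request"
    assert os.path.exists(os.path.join(tmp, "common.json"))
```

Note that a rank whose `local_rank` is nonzero skips the common write entirely and only runs the sharded save strategy, which is exactly why each node ends up with one copy of the common file.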

src/ml_flashpoint/adapter/nemo/checkpoint_io.py — 12 additions & 33 deletions

```diff
@@ -20,14 +20,13 @@
 import torch
 from lightning.fabric.utilities.types import _PATH
 from megatron.core import dist_checkpointing as mcore_dist_checkpointing
-from megatron.core.dist_checkpointing import state_dict_utils as mcore_state_dict_utils
 from megatron.core.dist_checkpointing.strategies.async_utils import (
     AsyncCallsQueue,
 )
 from megatron.core.dist_checkpointing.strategies.async_utils import (
     AsyncRequest as MegatronAsyncRequest,
 )
-from megatron.core.dist_checkpointing.strategies.common import COMMON_STATE_FNAME, TorchCommonLoadStrategy
+from megatron.core.dist_checkpointing.strategies.common import TorchCommonLoadStrategy
 from nemo.lightning.io.pl import MegatronCheckpointIO, TrainerContext, _fix_tensors_device
 from nemo.lightning.pytorch.trainer import Trainer
 from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO, AsyncFinalizableCheckpointIO
@@ -39,6 +38,7 @@
 from ml_flashpoint.adapter.megatron.save_strategies import (
     MLFlashpointMegatronAsyncSaveStrategy,
 )
+from ml_flashpoint.adapter.megatron.save_utils import save_local_aware_megatron_checkpoint
 from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import (
     CheckpointObjectManager,
 )
@@ -134,41 +134,20 @@ def save_checkpoint(
             return self.fallback_checkpoint_io.save_checkpoint(checkpoint, path)
         _LOGGER.info("Use ML Flashpoint checkpoint io. Async_save: '%s'", self.async_save)
 
-        # Mimic the CommonStrategy logic, on each rank.
-        # We split the "common" data from the "sharded" data, write the common data to a specific file on each rank,
-        # and continue with checkpointing on the "sharded" data.
-        # We do this explicitly here rather than using the mcore_dist_checkpointing.save API directly because that
-        # has rank-specific logic that writes common data only on global rank 0, and we want to write it on all ranks.
-        # Other than that, this logic should mimic `megatron.core.dist_checkpointing.save`.
-        sharded_state_dict, common_state_dict = mcore_state_dict_utils.save_preprocess(checkpoint)
-
-        if torch.distributed.get_node_local_rank() == 0:
-            # Since we are writing the common state directly here before executing the save orchestration,
-            # we need to ensure the parent checkpoint dir exists.
-            _LOGGER.debug("Saving common_state_dict...")
-            os.makedirs(path, exist_ok=True)
-            torch.save(common_state_dict, os.path.join(path, COMMON_STATE_FNAME))
+        # Use the helper for local-aware megatron save
+        optional_async_request = save_local_aware_megatron_checkpoint(
+            checkpoint=checkpoint,
+            checkpoint_dir=path,
+            save_strategy=self.save_strategy,
+            async_save=self.async_save,
+        )
+
+        # Handle optional context save (only if enabled)
         if self.always_save_context:
             _LOGGER.debug("Saving context...")
             self._save_context(path)
 
-        try:
-            if self.async_save:
-                async_request = self.save_strategy.async_save(
-                    sharded_state_dict=sharded_state_dict,
-                    checkpoint_dir=path,
-                )
-                return async_request
-            else:
-                # For sync save, no AsyncRequest is needed, so returning None.
-                self.save_strategy.save(
-                    sharded_state_dict=sharded_state_dict,
-                    checkpoint_dir=path,
-                )
-                return None
-        except Exception:
-            _LOGGER.exception("Failed to save ML Flashpoint checkpoint. Skipping saving and continuing.")
-            return None
+        return optional_async_request
 
     @log_execution_time(logger=_LOGGER, name="MLFlashpointCheckpointIO._save_context", level=logging.INFO)
     def _save_context(self, path: _PATH) -> Optional[threading.Thread]:
```
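One design choice worth noting in both the helper and the refactored `save_checkpoint()` is that save failures are logged and swallowed rather than propagated, so a failed checkpoint save never kills a training step. Reduced to its skeleton (the names below are illustrative, not from the codebase):

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def best_effort_save(save_fn: Callable[[], T], log_error: Callable[[str], None]) -> Optional[T]:
    # Mirrors the helper's `except Exception: ... return None` behavior:
    # the caller sees None instead of an exception, and training continues.
    try:
        return save_fn()
    except Exception as exc:
        log_error(f"Failed to save checkpoint, skipping and continuing: {exc!r}")
        return None


errors: list[str] = []
assert best_effort_save(lambda: "async-request", errors.append) == "async-request"
assert best_effort_save(lambda: 1 / 0, errors.append) is None
assert len(errors) == 1
```

The tradeoff, visible in the helper's docstring, is that a caller cannot distinguish a successful synchronous save from a failed one: both return None, and the failure is visible only in the logs.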
