Commit 5179e62

docs(site): Update user-guide for Megatron-LM usage
1 parent b976758 commit 5179e62

File tree

4 files changed (+93, -9 lines)


docs/user-guide.md

Lines changed: 81 additions & 8 deletions
@@ -15,7 +15,7 @@ pip install -e ml-flashpoint
 
 This assumes you are managing your own dependencies for PyTorch, Megatron-LM, NeMo, etc.
 
-To install with adapter specific dependencies, specify the adapter of interest, such as `nemo`, `megatron`, `pytorch`:
+To install with this library's adapter-specific dependencies, specify the adapter of interest, such as `nemo`, `megatron`, `pytorch`:
 
 ```bash
 # Example for the NeMo adapter.
@@ -28,7 +28,7 @@ See the project's [README](http://cs/h/cloud-mlnet/ml-flashpoint/+/main:README.m
 
 ### NeMo 2.0 & Pytorch Lightning
 
-Code: Check out the `adapter/nemo` package.
+Code: See the `ml_flashpoint.adapter.nemo` package.
 
 !!! note
 
@@ -55,7 +55,9 @@ mlflashpoint_base_path = _get_my_mlf_base_path()
 os.makedirs(mlflashpoint_base_path, exist_ok=True)
 ```
 
-2. Configure the callback, to trigger saves periodically.
+See the [system requirements](README.md#systemenvironment-requirements) for how to set up the filesystem to use.
+
+2. Configure the callback to trigger saves periodically.
 ```python
 # Add this callback to your Trainer's callbacks.
 callbacks.append(
@@ -77,13 +79,13 @@ auto_resume = nemo.lightning.AutoResume(...)
 # ML Flashpoint APIs below (and its AutoResume), to initialize replication.
 trainer.strategy.setup_environment()
 
-# Wrap the trainer and AutoResume to configure ML Flashpoint.
+# Wrap the trainer and AutoResume to configure ML Flashpoint using the provided helper.
 auto_resume = wrap_trainer_and_auto_resume_with_mlflashpoint(
     trainer=trainer,  # The PyTorch Lightning Trainer
     flashpoint_base_container=mlflashpoint_base_path,
     async_save=not args.sync_save,
     default_auto_resume=auto_resume,  # Optional
-    # always_save_context=True,  # Optional, defaults to False
+    # always_save_context=False,  # Optional, defaults to False
     # write_thread_count=1,  # Optional, defaults to 1
     # initial_write_buffer_size_bytes=DESIRED_NUM_BYTES,  # Optional, defaults to 16 GB
 )
@@ -95,7 +97,7 @@ A complete recipe example that puts this all together can be found [here](http:/
 
 Limitations:
 
-1. Must use the `MegatronStrategy` as the strategy for your PyTorch Lightning Trainer.
+1. You must use the `MegatronStrategy` as the strategy for your PyTorch Lightning Trainer.
    Other strategies have not been tested.
 1. Ensure that the `base_container` for ML Flashpoint is job-specific (i.e. has a job ID in it), and on some ramdisk path (e.g. tmpfs).
    The job ID should be unique across jobs, but sticky (reused) when a job is interrupted and restarted/rescheduled (so it can recover from the latest checkpoint available for that particular job).
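The job-specific, ramdisk-backed `base_container` limitation above can be sketched as follows. This is an illustrative helper, not part of the library's API; the `/dev/shm` root and the path layout are assumptions.

```python
# Sketch of building a job-specific base container on a ramdisk (tmpfs),
# per the limitation above. Illustrative only; not a library API.
import os

def mlf_base_container(job_id: str, ramdisk_root: str = "/dev/shm") -> str:
    # The job ID must be stable across restarts of the same job, so a
    # rescheduled job finds its own latest checkpoint; it must differ
    # between distinct jobs so their checkpoints never collide.
    return os.path.join(ramdisk_root, "ml_flashpoint", job_id)

print(mlf_base_container("job-1234"))  # /dev/shm/ml_flashpoint/job-1234
```

In practice the job ID would come from your scheduler (e.g. a job UID that survives preemption), not a hard-coded string.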
@@ -105,8 +107,79 @@ This reduces blocking time by avoiding duplicate work, at the cost of having a l
 
 ### Megatron-LM
 
-Check out the `adapter/megatron` package.
+Code: See the `ml_flashpoint.adapter.megatron` package.
+
+The Megatron strategies depend on the PyTorch DCP implementations.
+Below are instructions for setting up ML Flashpoint checkpointing, which you should configure alongside regular checkpointing to long-term storage.
+
+#### Save Strategy
+
+First create a `MemoryStorageWriter` instance as outlined in [PyTorch DCP](#pytorch-dcp).
+Then use that to instantiate the Megatron save strategy:
+
+```python
+# Instantiate the MemoryStorageWriter
+memory_storage_writer = MemoryStorageWriter(...)
+
+# Use it to instantiate the Save Strategy
+megatron_save_strategy = MLFlashpointMegatronAsyncSaveStrategy(
+    storage_writer=memory_storage_writer,
+)
+```
+
+Because Megatron's `dist_checkpointing.save()` function writes "common" data only on global rank 0, which does not align with local checkpointing, you can orchestrate saves using the save strategy the same way it's done in [`MLFlashpointCheckpointIO.save_checkpoint()`](https://github.com/google/ml-flashpoint/blob/b9767583520106f59743b9e8050769523cfbef6e/src/ml_flashpoint/adapter/nemo/checkpoint_io.py#L137-L171) in the `adapter.nemo` package.
+You'll notice that the logic there aims to mimic `dist_checkpointing.save`, but it saves common data on each node (via local rank 0) as opposed to solely on the coordinator node (global rank 0).
+
+Use this strategy on a more frequent interval than your regular long-term storage checkpointing strategy.
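The per-node "common"-data save described above boils down to a rank-gating rule. A minimal sketch of that rule, assuming a homogeneous `gpus_per_node` layout (this is an illustration of the behavior, not the library's implementation):

```python
# Which global ranks persist "common" (non-sharded) checkpoint data.
# local=False mimics Megatron's dist_checkpointing.save(), where only
# global rank 0 writes common data; local=True mimics the node-local
# scheme, where each node's local rank 0 writes its own copy.

def common_writer_ranks(world_size: int, gpus_per_node: int, local: bool) -> list[int]:
    if not local:
        return [0]  # coordinator node only
    # Local rank 0 on each node is the rank whose global index is a
    # multiple of the per-node process count.
    return [rank for rank in range(world_size) if rank % gpus_per_node == 0]

# 2 nodes x 4 GPUs:
print(common_writer_ranks(8, 4, local=False))  # [0]
print(common_writer_ranks(8, 4, local=True))   # [0, 4]
```

Writing common data once per node is what lets a surviving node restore a complete checkpoint without reaching the coordinator node.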
+
+#### Load Strategy
+
+Instantiate the singleton `ReplicationManager` with a singleton `CheckpointObjectManager`, and make sure to `initialize()` the `ReplicationManager` before using it.
+Also create an `MLFlashpointCheckpointLoader` with those dependencies, and use these instances to create the load strategy:
+
+```python
+# Initialize dependencies (shared singletons)
+checkpoint_object_manager = CheckpointObjectManager()
+replication_manager = ReplicationManager()
+replication_manager.initialize(checkpoint_object_manager)
+
+checkpoint_loader = MLFlashpointCheckpointLoader(
+    checkpoint_object_manager=checkpoint_object_manager,
+    replication_manager=replication_manager,
+)
+
+# Instantiate the Load Strategy with the dependencies
+mlflashpoint_load_strategy = MLFlashpointMegatronLoadStrategy(
+    replication_manager=replication_manager,
+    checkpoint_loader=checkpoint_loader,
+)
+```
+
+Now you can use the load strategy with Megatron-LM's `dist_checkpointing.load` function directly:
+
+```python
+# First determine if an ML Flashpoint checkpoint is available, using the base container path you've configured
+local_checkpoint_container = checkpoint_loader.get_latest_complete_checkpoint(checkpoint_base_container)
+
+if local_checkpoint_container is None:
+    # Load using your regular sharded strategy from your long-term storage path
+    state_dict = mcore_dist_checkpointing.load(
+        sharded_state_dict=sharded_state_dict,
+        checkpoint_dir=str(long_term_storage_path),
+        sharded_strategy=regular_megatron_load_strategy,
+        common_strategy=TorchCommonLoadStrategy(),
+    )
+else:
+    # Because the existing load function does nothing rank-specific,
+    # it is suitable to use directly here.
+    state_dict = mcore_dist_checkpointing.load(
+        sharded_state_dict=sharded_state_dict,
+        checkpoint_dir=str(local_checkpoint_container),
+        sharded_strategy=mlflashpoint_load_strategy,
+        common_strategy=TorchCommonLoadStrategy(),
+    )
+```
 
 ### PyTorch DCP
 
-Check out the `adapter/pytorch` package.
+Code: See the `ml_flashpoint.adapter.pytorch` package.

mkdocs.yml

Lines changed: 9 additions & 0 deletions
@@ -47,6 +47,15 @@ plugins:
   # - nav-weight # To allow specifying nav_weight: <int> at the top of a doc for ordering
   # To enable generating documentation that can be served from a local filesystem.
   # - offline
+  - search
+  - mkdocstrings:
+      handlers:
+        python:
+          options:
+            docstring_style: google  # Matches project convention
+            show_source: true
+            show_root_heading: true
+            show_root_full_path: false
 
 markdown_extensions:
   - admonition

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,8 @@ nemo = [
 docs = [
     # Documentation static-site-generator (mkdocs) with theme support.
     "mkdocs-material==9.7.0",
+    # For collecting source code and docstrings into documentation.
+    "mkdocstrings-python==2.0.1",
     # Documentation static-site-generator (mkdocs) with theme support. Successor of mkdocs-material.
     #"zensical==0.0.11",
 ]

src/ml_flashpoint/adapter/nemo/checkpoint_io.py

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ def save_checkpoint(
 
         Returns:
             An `Optional[MegatronAsyncRequest]` if `async_save` is `True` and the save is successful,
-                otherwise `None`.
+            otherwise `None`.
         """
         if not _is_ml_flashpoint_checkpoint(self.flashpoint_base_dir, path):
            _LOGGER.info("Fallback to alternative checkpoint io.")
