
Commit a2450dc (parent: e113677)

docs(site): add PyTorch DCP usage guide; segregate imports (#17)
File tree: 3 files changed, +131 −26 lines

docs/README.md

Lines changed: 4 additions & 4 deletions

@@ -53,7 +53,7 @@ To use ML Flashpoint, the basic requirements for the training environment are:
 1. Python 3.10 or later.
 1. Linux operating system on the training nodes.
 1. An even number of training nodes, to use the pairwise replication strategy.
-   This is enforced so that the pairwise strategy doesn't put a higher memory burden on one node than the others, and so the general capacity requirements are roughly consistent across nodes.
+   * This is enforced so that the pairwise strategy doesn't put a higher memory burden on one node than the others, and so the general capacity requirements are roughly consistent across nodes.
 1. A `tmpfs` mount is strongly recommended to be used for the container base path, that is separate from `/dev/shm`.
    E.g. a `/tmp` mount, which can be added to `/etc/fstab` on Linux machines to mount it persistently (A3-Mega example):
    1. `tmpfs /tmp tmpfs rw,nosuid,nodev,size=1024G,mode=1777,noswap,huge=within_size 0 0`
@@ -63,13 +63,13 @@ E.g. a `/tmp` mount, which can be added to `/etc/fstab` on Linux machines to mount it persistently (A3-Mega example):
 1. The amount of memory needed is at least equal to the checkpoint size per node x 4, to account for replicas and in-progress checkpoints.
    Typically, `/tmp` is set to 50% of host RAM (higher is OK).
 1. The base container specified for ML Flashpoint should be specific to the running job ID, which will store all checkpoints for that job, and will be used for recovery in that particular job.
-   The job ID is important to include in the path because it ensures that different training jobs do not conflict, and that recovery is done correctly.
+   * The job ID is important to include in the path because it ensures that different training jobs do not conflict, and that recovery is done correctly.
    * The assumption is that a new job ID is assigned for every new training job, and that it is reused when a job is resumed or re-queued due to an interruption.
    * The recovery logic typically (when configured correctly) always checks at job start whether some complete checkpoint is available in the job's checkpoint container, and if so will load it and resume from there.
 1. When a job recovers after some interruption, it should _reuse all the same machines_ it initially used that are still healthy, only replacing machines that need to be replaced.
    (If a process can be restarted without replacing the machine, recovery will be even quicker.)
-   Given checkpointing state is kept in-memory, this is essential to take advantage of ML Flashpoint checkpoints and be able to recover from them.
-   If the job is resumed or re-queued on a different set of nodes, or with a different job ID, there will be no ML Flashpoint state to recover from, forcing a fallback to the long-term storage checkpoints, which is slower.
+   * Given checkpointing state is kept in-memory, this is essential to take advantage of ML Flashpoint checkpoints and be able to recover from them.
+   * If the job is resumed or re-queued on a different set of nodes, or with a different job ID, there will be no ML Flashpoint state to recover from, forcing a fallback to the long-term storage checkpoints, which is slower.

 ## Framework Layers
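The sizing rule in the hunk above (at least checkpoint size per node x 4, with `/tmp` typically at 50% of host RAM) can be sanity-checked with a short sketch. This is illustrative arithmetic only; the helper names and example sizes are assumptions, not part of ML Flashpoint.

```python
def required_tmpfs_bytes(checkpoint_bytes_per_node: int) -> int:
    """Minimum tmpfs capacity per node: checkpoint size x 4, accounting for
    the local copy, its pairwise replica, and in-progress versions of each."""
    return checkpoint_bytes_per_node * 4


def tmpfs_fits(checkpoint_bytes_per_node: int, host_ram_bytes: int,
               tmpfs_fraction: float = 0.5) -> bool:
    """Check whether a tmpfs sized at `tmpfs_fraction` of host RAM can hold
    the required checkpoint state."""
    return required_tmpfs_bytes(checkpoint_bytes_per_node) <= host_ram_bytes * tmpfs_fraction


GiB = 1024 ** 3
# Hypothetical host with ~1872 GiB RAM and /tmp at ~50% of that (~936 GiB):
print(tmpfs_fits(200 * GiB, 1872 * GiB))  # 800 GiB needed -> True
print(tmpfs_fits(300 * GiB, 1872 * GiB))  # 1200 GiB needed -> False
```

If the check fails, either mount a larger `/tmp` (the fstab example above uses `size=1024G`) or reduce per-node checkpoint size.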

docs/user-guide.md

Lines changed: 126 additions & 21 deletions

@@ -112,17 +112,31 @@ Code: See the [`ml_flashpoint.adapter.megatron`](https://github.com/google/ml-fl
 The Megatron strategies depend on the PyTorch DCP implementations.
 Below are instructions for setting up ML Flashpoint checkpointing, which you should configure alongside regular checkpointing to long-term storage.
 
-#### Save Strategy
-
-First create a `MemoryStorageWriter` instance as outlined in [PyTorch DCP](#pytorch-dcp).
-Then use that to instantiate the Megatron save strategy.
+#### Imports
 
 ```python
+# Saving
 from ml_flashpoint.adapter.pytorch.memory_storage_writer import MemoryStorageWriter
 from ml_flashpoint.adapter.megatron.save_strategies import (
     MLFlashpointMegatronAsyncSaveStrategy,
 )
 
+# Loading
+from ml_flashpoint.adapter.megatron.load_strategies import MLFlashpointMegatronLoadStrategy
+from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
+from ml_flashpoint.core.checkpoint_loader import DefaultMLFlashpointCheckpointLoader
+from ml_flashpoint.replication.replication_manager import ReplicationManager
+
+# Megatron Checkpointing
+from megatron.core import dist_checkpointing as mcore_dist_checkpointing
+```
+
+#### Save Strategy
+
+First create a `MemoryStorageWriter` instance as outlined in [PyTorch DCP](#pytorch-dcp).
+Then use that to instantiate the Megatron save strategy.
+
+```python
 # Instantiate the MemoryStorageWriter
 memory_storage_writer = MemoryStorageWriter(...)
 
@@ -135,6 +149,13 @@ megatron_save_strategy = MLFlashpointMegatronAsyncSaveStrategy(
 Because Megatron's `dist_checkpointing.save()` function writes "common" data only on global rank 0, which does not align with local checkpointing, you can orchestrate saves using the save strategy the same way it's done in [`MLFlashpointCheckpointIO.save_checkpoint()`](https://github.com/google/ml-flashpoint/blob/b9767583520106f59743b9e8050769523cfbef6e/src/ml_flashpoint/adapter/nemo/checkpoint_io.py#L137-L171) in the `ml_flashpoint.adapter.nemo` package.
 You'll notice that the logic there aims to mimic `dist_checkpointing.save`, but it saves common data on each node (via local rank 0) as opposed to solely on the coordinator node (global rank 0).
 
+!!! note
+
+    Make sure to specify the checkpoint ID/path when saving based on the current step using:
+    `CheckpointContainerId.create_child(base_container, CheckpointContainerId.format_version_container(current_step))`
+    where `base_container` is the base path `CheckpointContainerId` used for all checkpoints for the current job, e.g. `"/tmp/mlf-checkpoints/job123"`.
+
+
 Use this strategy on a more frequent interval than your regular long-term storage checkpointing strategy.
 
 #### Load Strategy
@@ -143,12 +164,6 @@ Instantiate the singleton `ReplicationManager` with a singleton `CheckpointObjec
 Also create an `MLFlashpointCheckpointLoader` with those dependencies, and use these instances to create the load strategy:
 
 ```python
-from ml_flashpoint.adapter.megatron.load_strategies import MLFlashpointMegatronLoadStrategy
-from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
-from ml_flashpoint.core.checkpoint_loader import DefaultMLFlashpointCheckpointLoader
-from ml_flashpoint.replication.replication_manager import ReplicationManager
-
-
 # Initialize dependencies (shared singletons)
 checkpoint_object_manager = CheckpointObjectManager()
 replication_manager = ReplicationManager()
@@ -169,30 +184,120 @@ mlflashpoint_load_strategy = MLFlashpointMegatronLoadStrategy(
 Now you can use the load strategy with Megatron-LM's `dist_checkpointing.load` function directly:
 
 ```python
-from megatron.core import dist_checkpointing as mcore_dist_checkpointing
-
 # First determine if an ML Flashpoint checkpoint is available, using the base container path you've configured
-local_checkpoint_container = checkpoint_loader.get_latest_complete_checkpoint(checkpoint_base_container)
+latest_saved_checkpoint_id = checkpoint_loader.get_latest_complete_checkpoint(checkpoint_base_container)
 
-if local_container is None:
-    # Load using your regular sharded strategy from your long-term storage path
+if latest_saved_checkpoint_id:
+    # Given the existing load function doesn't do anything rank-specific,
+    # it is suitable for us to use directly.
     state_dict = mcore_dist_checkpointing.load(
         sharded_state_dict=sharded_state_dict,
-        checkpoint_dir=str(long_term_storage_path),
-        sharded_strategy=regular_megatron_load_strategy,
+        checkpoint_dir=str(latest_saved_checkpoint_id),
+        sharded_strategy=mlflashpoint_load_strategy,
         common_strategy=TorchCommonLoadStrategy(),
     )
 else:
-    # Given the existing load function doesn't do anything rank-specific,
-    # it is suitable for us to use directly.
+    # Load using your regular sharded strategy from your long-term storage path
     state_dict = mcore_dist_checkpointing.load(
         sharded_state_dict=sharded_state_dict,
-        checkpoint_dir=str(local_checkpoint_container),
-        sharded_strategy=mlflashpoint_load_strategy,
+        checkpoint_dir=str(long_term_storage_path),
+        sharded_strategy=regular_megatron_load_strategy,
         common_strategy=TorchCommonLoadStrategy(),
     )
 ```
 
 ### PyTorch DCP
 
 Code: See the [`ml_flashpoint.adapter.pytorch`](https://github.com/google/ml-flashpoint/tree/main/src/ml_flashpoint/adapter/pytorch) package.
+
+To use directly with PyTorch DCP, use the provided `StorageWriter` and `StorageReader` implementations.
+You can use whatever `Planner` implementations work for your use case, or resort to the defaults.
+
+If your per-rank checkpoint data exceeds the default buffer size (16 GB as of this writing), you can increase it using the optional `initial_buffer_size_bytes` parameter.
+
+#### Imports
+```python
+import torch
+from torch import multiprocessing as torch_mp
+import torch.distributed.checkpoint as dcp
+
+from ml_flashpoint.adapter.megatron.load_strategies import MLFlashpointMegatronLoadStrategy
+from ml_flashpoint.adapter.pytorch.memory_storage_writer import MemoryStorageWriter
+from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
+from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId
+from ml_flashpoint.core.checkpoint_loader import DefaultMLFlashpointCheckpointLoader
+from ml_flashpoint.core.checkpoint_saver import DefaultMLFlashpointCheckpointSaver
+from ml_flashpoint.replication.replication_manager import ReplicationManager
+```
+
+#### Initialization
+```python
+# Initialize dependencies (shared singletons)
+checkpoint_object_manager = CheckpointObjectManager()
+replication_manager = ReplicationManager()
+replication_manager.initialize(checkpoint_object_manager)
+
+# Instantiate the StorageWriter
+memory_storage_writer = MemoryStorageWriter(
+    checkpoint_saver=DefaultMLFlashpointCheckpointSaver(
+        global_rank_getter=torch.distributed.get_rank,
+        local_rank_getter=torch.distributed.get_node_local_rank,
+        global_barrier_func=lambda: torch.distributed.barrier(),
+        ckpt_obj_manager=checkpoint_object_manager,
+        replication_manager=replication_manager,
+        # initial_buffer_size_bytes=initial_write_buffer_size_bytes,  # Optional - increase for larger checkpoint sizes per rank
+    ),
+    mp_manager=torch_mp.Manager(),
+)
+
+# Instantiate the CheckpointLoader and StorageReader
+checkpoint_loader = DefaultMLFlashpointCheckpointLoader(
+    checkpoint_object_manager=checkpoint_object_manager,
+    replication_manager=replication_manager,
+)
+memory_storage_reader = MemoryStorageReader(
+    path=checkpoint_dir,
+    checkpoint_loader=checkpoint_loader,
+)
+```
+
+#### Saving
+
+Now you can use the `MemoryStorageWriter` when saving checkpoints as you normally do with DCP, e.g.:
+
+```python
+# Assuming base_container is the base path CheckpointContainerId used for all checkpoints for the current job, e.g. "/tmp/mlf-checkpoints/job123":
+curr_step_checkpoint_id = CheckpointContainerId.create_child(
+    base_container, CheckpointContainerId.format_version_container(current_step)
+)
+
+# Sync save
+metadata = dcp.save(state_dict,
+                    checkpoint_id=str(curr_step_checkpoint_id),
+                    storage_writer=memory_storage_writer)
+
+# Async save
+future = dcp.async_save(state_dict,
+                        checkpoint_id=str(curr_step_checkpoint_id),
+                        storage_writer=memory_storage_writer,
+                        async_checkpointer_type=dcp.AsyncCheckpointerType.PROCESS)
+```
+
+#### Recovery
+During a recovery scenario, use the `checkpoint_loader` to first identify the latest available ML Flashpoint checkpoint, if any, to recover from.
+If none, fall back to your long-term storage checkpoint.
+
+```python
+# First determine if an ML Flashpoint checkpoint is available, using the base container path you've configured
+latest_saved_checkpoint_id = checkpoint_loader.get_latest_complete_checkpoint(checkpoint_base_container)
+
+if latest_saved_checkpoint_id:
+    dcp.load(state_dict,
+             checkpoint_id=str(latest_saved_checkpoint_id),
+             storage_reader=memory_storage_reader)
+else:
+    # Load using your regular sharded strategy from your long-term storage path
+    dcp.load(state_dict,
+             checkpoint_id=str(long_term_checkpoint_path),
+             ...)
+```
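The recovery flow added in this guide (try the latest complete in-memory checkpoint, otherwise fall back to long-term storage) can be illustrated with a small stdlib-only stand-in. This is not the ml_flashpoint API; the function, the candidate dictionary, and the paths are hypothetical, mimicking only the selection logic described above.

```python
import re
from typing import Optional


def latest_complete_checkpoint(containers: dict[str, bool]) -> Optional[str]:
    """Toy stand-in for checkpoint_loader.get_latest_complete_checkpoint:
    given candidate container paths mapped to a completeness flag, return
    the complete candidate with the highest trailing step number, or None."""
    def step_of(path: str) -> int:
        match = re.search(r"(\d+)$", path)
        return int(match.group(1)) if match else -1

    complete = [path for path, done in containers.items() if done]
    return max(complete, key=step_of, default=None)


candidates = {
    "/tmp/mlf-checkpoints/job123/step_300": False,  # save was interrupted
    "/tmp/mlf-checkpoints/job123/step_200": True,
    "/tmp/mlf-checkpoints/job123/step_100": True,
}
ckpt = latest_complete_checkpoint(candidates)
# Fall back to a long-term storage path when no in-memory checkpoint survives
source = ckpt if ckpt else "/long-term-storage/job123/latest"
print(source)  # /tmp/mlf-checkpoints/job123/step_200
```

The interrupted step-300 save is skipped, so recovery resumes from step 200, the newest complete checkpoint; long-term storage is consulted only when the dictionary holds no complete candidate.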

src/ml_flashpoint/core/checkpoint_loader.py

Lines changed: 1 addition & 1 deletion

@@ -292,7 +292,7 @@ def get_latest_complete_checkpoint(
     self, checkpoint_base_container: CheckpointContainerId
 ) -> Optional[CheckpointContainerId]:
     """
-    Step 1: call get_candidate_checkpoints to get all existing checkpoint containers cross
+    Step 1: call get_candidate_checkpoints to get all existing checkpoint containers across
     all ranks as candidates and sorted in a descending order by step
     Step 2: traverse the candidate checkpoints and for each checkpoint, for each candidate:
     - call get_checkpoint_objects_by_rank to get all existing checkpoint objects cross
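The two-step algorithm described by this docstring can be sketched in a few lines. The data shapes here are assumptions for illustration (a step-to-per-rank-objects mapping), not the actual internal structures of `DefaultMLFlashpointCheckpointLoader`.

```python
from typing import Optional


def get_latest_complete(candidates: dict[int, dict[int, list[str]]],
                        world_size: int) -> Optional[int]:
    """Illustrative sketch of the docstring's algorithm: candidates maps
    step -> {rank: checkpoint objects}. Traverse candidates in descending
    step order and return the first step where every rank has objects."""
    for step in sorted(candidates, reverse=True):   # Step 1: descending by step
        by_rank = candidates[step]                  # Step 2: per-candidate check
        if all(by_rank.get(rank) for rank in range(world_size)):
            return step
    return None


candidates = {
    100: {0: ["shard-0"], 1: ["shard-1"]},
    200: {0: ["shard-0"], 1: []},  # rank 1 objects missing -> incomplete
}
print(get_latest_complete(candidates, world_size=2))  # 100
```

Step 200 is newer but incomplete (rank 1 has no objects), so the sketch returns 100, mirroring the "latest complete" semantics the docstring describes.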
