To use ML Flashpoint, the basic requirements for the training environment are:

1. Python 3.10 or later.
1. Linux operating system on the training nodes.
1. An even number of training nodes, to use the pairwise replication strategy.

   *This is enforced so that the pairwise strategy doesn't put a higher memory burden on one node than on the others, and so the general capacity requirements are roughly consistent across nodes.*

1. A `tmpfs` mount, separate from `/dev/shm`, is strongly recommended for the container base path.
   E.g. a `/tmp` mount, which can be added to `/etc/fstab` on Linux machines to mount it persistently (A3-Mega example):
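   A generic entry might look like the following; the size and options here are illustrative placeholders, not the actual A3-Mega values:

   ```
   # Illustrative only: mount /tmp as tmpfs sized to 50% of host RAM
   tmpfs  /tmp  tmpfs  defaults,size=50%  0  0
   ```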
1. The amount of memory needed is at least equal to the checkpoint size per node x 4, to account for replicas and in-progress checkpoints.
   Typically, `/tmp` is set to 50% of host RAM (higher is OK).
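   The sizing rule above can be sanity-checked with a small sketch (`min_tmpfs_bytes` is a hypothetical helper for illustration, not part of ML Flashpoint):

   ```python
   def min_tmpfs_bytes(checkpoint_bytes_per_node: int) -> int:
       # 4x the per-node checkpoint size, to cover the local copy, its
       # pairwise replica, and in-progress checkpoints being written.
       return 4 * checkpoint_bytes_per_node

   # A 100 GiB per-node checkpoint needs at least 400 GiB of tmpfs.
   assert min_tmpfs_bytes(100 * 1024**3) == 400 * 1024**3
   ```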
1. The base container specified for ML Flashpoint should be specific to the running job ID; it will store all checkpoints for that job and will be used for recovery in that particular job.

   *The job ID is important to include in the path because it ensures that different training jobs do not conflict, and that recovery is done correctly.*

   * The assumption is that a new job ID is assigned for every new training job, and that it is reused when a job is resumed or re-queued due to an interruption.
   * The recovery logic (when configured correctly) checks at job start whether some complete checkpoint is available in the job's checkpoint container, and if so will load it and resume from there.
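   As a sketch, a job-scoped container path could be derived like this (the directory layout is an assumption for illustration, not a fixed ML Flashpoint convention):

   ```python
   from pathlib import Path

   def checkpoint_container_path(job_id: str, base: str = "/tmp") -> Path:
       # Embedding the job ID keeps concurrent jobs from colliding and lets a
       # resumed or re-queued run (same job ID) find its own checkpoints.
       return Path(base) / "checkpoints" / job_id

   print(checkpoint_container_path("job-12345"))  # -> /tmp/checkpoints/job-12345
   ```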
1. When a job recovers after an interruption, it should _reuse all the same machines_ it initially used that are still healthy, only replacing machines that need to be replaced.
   (If a process can be restarted without replacing the machine, recovery will be even quicker.)

   * Because checkpointing state is kept in memory, this is essential to take advantage of ML Flashpoint checkpoints and be able to recover from them.
   * If the job is resumed or re-queued on a different set of nodes, or with a different job ID, there will be no ML Flashpoint state to recover from, forcing a fallback to the long-term storage checkpoints, which is slower.
Because Megatron's `dist_checkpointing.save()` function writes "common" data only on global rank 0, which does not align with local checkpointing, you can orchestrate saves using the save strategy the same way it's done in [`MLFlashpointCheckpointIO.save_checkpoint()`](https://github.com/google/ml-flashpoint/blob/b9767583520106f59743b9e8050769523cfbef6e/src/ml_flashpoint/adapter/nemo/checkpoint_io.py#L137-L171) in the `ml_flashpoint.adapter.nemo` package.
You'll notice that the logic there aims to mimic `dist_checkpointing.save`, but it saves common data on each node (via local rank 0) as opposed to solely on the coordinator node (global rank 0).
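To make the difference concrete, here is a minimal sketch of the rank-selection logic (a standalone illustration, not the actual `save_checkpoint()` code; `should_write_common` is a hypothetical helper, and it assumes ranks are assigned contiguously per node):

```python
def should_write_common(global_rank: int, local_world_size: int) -> bool:
    # Standard `dist_checkpointing.save` writes "common" (unsharded) data only
    # on global rank 0; for node-local checkpoints it must instead be written
    # by local rank 0 on every node, so each node's container is self-contained.
    return global_rank % local_world_size == 0

# With 8 GPUs per node, ranks 0, 8, 16, ... each write the common data.
assert should_write_common(0, 8) and should_write_common(8, 8)
assert not should_write_common(3, 8)
```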
!!! note

    Make sure to specify the checkpoint ID/path when saving based on the current step.
```python
if local_checkpoint_container:
    # Given the existing load function doesn't do anything rank-specific,
    # it is suitable for us to use directly.
    state_dict = mcore_dist_checkpointing.load(
        sharded_state_dict=sharded_state_dict,
        checkpoint_dir=str(latest_saved_checkpoint_id),
        sharded_strategy=mlflashpoint_load_strategy,
        common_strategy=TorchCommonLoadStrategy(),
    )
else:
    # Load using your regular sharded strategy from your long-term storage path.
    state_dict = mcore_dist_checkpointing.load(
        sharded_state_dict=sharded_state_dict,
        checkpoint_dir=str(long_term_storage_path),
        sharded_strategy=regular_megatron_load_strategy,
        common_strategy=TorchCommonLoadStrategy(),
    )
```
### PyTorch DCP
Code: See the [`ml_flashpoint.adapter.pytorch`](https://github.com/google/ml-flashpoint/tree/main/src/ml_flashpoint/adapter/pytorch) package.
To use directly with PyTorch DCP, use the provided `StorageWriter` and `StorageReader` implementations.
You can use whatever `Planner` implementations work for your use case, or resort to the defaults.

If your per-rank checkpoint data exceeds the default buffer size (16 GB as of this writing), you can increase it using the optional `initial_buffer_size_bytes` parameter.
#### Imports
```python
import torch
from torch import multiprocessing as torch_mp
import torch.distributed.checkpoint as dcp

from ml_flashpoint.adapter.megatron.load_strategies import MLFlashpointMegatronLoadStrategy
from ml_flashpoint.adapter.pytorch.memory_storage_writer import MemoryStorageWriter
from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId
from ml_flashpoint.core.checkpoint_loader import DefaultMLFlashpointCheckpointLoader
from ml_flashpoint.core.checkpoint_saver import DefaultMLFlashpointCheckpointSaver
from ml_flashpoint.replication.replication_manager import ReplicationManager
```