    # always_save_context=False, # Optional, defaults to False
    # write_thread_count=1, # Optional, defaults to 1
    # initial_write_buffer_size_bytes=DESIRED_NUM_BYTES, # Optional, defaults to 16 GB
)
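For readability, the desired buffer size can be computed rather than hard-coded. A minimal sketch, assuming the parameter takes a raw byte count and treating the documented "16 GB" as binary gigabytes (GiB); `gib_to_bytes` is a hypothetical helper, not part of the library:

```python
# Hypothetical helper (not part of ml-flashpoint): compute the byte
# count to pass as initial_write_buffer_size_bytes.
GIB = 1024 ** 3  # assumption: the "16 GB" default means GiB

def gib_to_bytes(gib: float) -> int:
    """Convert GiB to a raw byte count."""
    return int(gib * GIB)

DESIRED_NUM_BYTES = gib_to_bytes(16)  # 17_179_869_184 bytes
```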
Limitations:

1. You must use the `MegatronStrategy` as the strategy for your PyTorch Lightning Trainer. Other strategies have not been tested.
1. Ensure that the `base_container` for ML Flashpoint is job-specific (i.e. has a job ID in it), and on some ramdisk path (e.g. tmpfs). The job ID should be unique across jobs, but sticky (reused) when a job is interrupted and restarted/rescheduled (so it can recover from the latest checkpoint available for that particular job).
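As an illustration of the requirement above, the `base_container` path might be derived from the scheduler's job ID. This is a sketch under assumptions: `/dev/shm` as the tmpfs mount and `SLURM_JOB_ID` as the job-ID source are examples, not requirements of ML Flashpoint, and `job_base_container` is a hypothetical helper:

```python
import os

def job_base_container(job_id: str, ramdisk: str = "/dev/shm") -> str:
    """Build a job-specific base_container path on a ramdisk (tmpfs).

    The job ID is unique per job but reused across restarts of the same
    job, so a rescheduled job can find its own latest checkpoint.
    """
    return os.path.join(ramdisk, f"ml_flashpoint_{job_id}")

# Example: take the job ID from the scheduler (Slurm shown as one option).
base_container = job_base_container(os.environ.get("SLURM_JOB_ID", "job-1234"))
```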
### Megatron-LM

Code: See the `ml_flashpoint.adapter.megatron` package.

The Megatron strategies depend on the PyTorch DCP implementations. Below are instructions for setting up ML Flashpoint checkpointing, which you should configure alongside regular checkpointing to long-term storage.
#### Save Strategy
First create a `MemoryStorageWriter` instance as outlined in [PyTorch DCP](#pytorch-dcp).
Then use that to instantiate the Megatron save strategy:

Because Megatron's `dist_checkpointing.save()` function writes "common" data only on global rank 0, which does not align with local checkpointing, you can orchestrate saves using the save strategy the same way it's done in [`MLFlashpointCheckpointIO.save_checkpoint()`](https://github.com/google/ml-flashpoint/blob/b9767583520106f59743b9e8050769523cfbef6e/src/ml_flashpoint/adapter/nemo/checkpoint_io.py#L137-L171) in the `adapter.nemo` package. You'll notice that the logic there aims to mimic `dist_checkpointing.save`, but it saves common data on each node (via local rank 0) as opposed to solely on the coordinator node (global rank 0).
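The difference between the two behaviors can be sketched with plain rank arithmetic. This is illustrative only (the real orchestration lives in `MLFlashpointCheckpointIO.save_checkpoint()`), and it assumes a contiguous rank-to-node mapping:

```python
def common_data_writers(world_size: int, ranks_per_node: int, local: bool) -> list[int]:
    """Global ranks that write the "common" checkpoint data.

    Megatron's dist_checkpointing.save() writes it only on global rank 0;
    for node-local checkpointing, each node's local rank 0 writes it instead.
    """
    if not local:
        return [0]  # coordinator node only (global rank 0)
    # Local rank 0 on each node, assuming ranks are contiguous per node.
    return [r for r in range(world_size) if r % ranks_per_node == 0]

# Two nodes with 4 ranks each:
assert common_data_writers(8, 4, local=False) == [0]
assert common_data_writers(8, 4, local=True) == [0, 4]
```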
Use this strategy on a more frequent interval than your regular long-term storage checkpointing strategy.
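For example, the two save schedules might interleave like this (the interval values and the `saves_due` helper are illustrative, not part of ML Flashpoint):

```python
LOCAL_INTERVAL = 50       # steps between ML Flashpoint (in-memory) saves
LONG_TERM_INTERVAL = 500  # steps between long-term storage saves

def saves_due(step: int) -> list[str]:
    """Checkpoint backends that should fire at a given training step."""
    due = []
    if step > 0 and step % LOCAL_INTERVAL == 0:
        due.append("ml_flashpoint")
    if step > 0 and step % LONG_TERM_INTERVAL == 0:
        due.append("long_term_storage")
    return due
```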
#### Load Strategy
Instantiate the singleton `ReplicationManager` with a singleton `CheckpointObjectManager`, and make sure to `initialize()` the `ReplicationManager` before using it. Also create an `MLFlashpointCheckpointLoader` with those dependencies, and use these instances to create the load strategy: