Skip to content

Commit 9467487

Browse files
authored
Checkpointing cleanup (esm2/llama3) (#1366)
Miscellaneous edits and improvements to checkpointing in the esm2 and llama3 training recipes: * adds the ability to keep only the last N checkpoints for a training run * adds async FSDP2 checkpointing, although this currently seems buggy * adds a CUDA cache cleanup step to prevent OOM errors when training Llama3 with torch distributed checkpointing. --------- Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 1071105 commit 9467487

File tree

15 files changed

+194
-185
lines changed

15 files changed

+194
-185
lines changed

bionemo-recipes/recipes/esm2_native_te/checkpoint.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import gc
1617
import logging
1718
import os
19+
import shutil
1820
from dataclasses import dataclass, field
1921
from pathlib import Path
2022
from typing import NamedTuple
2123

2224
import torch
23-
import torch.distributed.checkpoint as dcp
2425
import transformers
2526
from safetensors.torch import save_file
2627
from torch.distributed.checkpoint.state_dict import (
@@ -29,13 +30,17 @@
2930
get_state_dict,
3031
set_state_dict,
3132
)
33+
from torch.distributed.checkpoint.state_dict_loader import load as dcp_load
34+
from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save
35+
from torch.distributed.checkpoint.state_dict_saver import save as dcp_save
3236
from torch.distributed.checkpoint.stateful import Stateful
3337
from torchdata.stateful_dataloader import StatefulDataLoader
3438

3539
from distributed_config import DistributedConfig
3640

3741

3842
logger = logging.getLogger(__name__)
43+
_ckpt_futures: dict = {}
3944

4045

4146
class CheckpointOutput(NamedTuple):
@@ -82,6 +87,20 @@ def should_save_checkpoint(step: int, save_every_n_steps: int) -> bool:
8287
return False
8388

8489

90+
def prune_checkpoints(ckpt_path: str | os.PathLike, max_checkpoints: int) -> None:
91+
"""Prune checkpoints to keep only the latest `max_checkpoints` checkpoints."""
92+
ckpt_path = Path(ckpt_path)
93+
checkpoints = [f for f in ckpt_path.iterdir() if f.name.startswith("step_")]
94+
checkpoints.sort(key=lambda x: int(Path(x).stem.split("_")[1]))
95+
if len(checkpoints) > max_checkpoints:
96+
for checkpoint in checkpoints[:-max_checkpoints]:
97+
logger.info(f"Pruning checkpoint {checkpoint}")
98+
if checkpoint.is_dir():
99+
shutil.rmtree(checkpoint)
100+
else:
101+
os.remove(checkpoint)
102+
103+
85104
# ============================================================================
86105
# DDP Checkpointing
87106
# ============================================================================
@@ -131,6 +150,7 @@ def save_checkpoint_ddp(
131150
epoch: int,
132151
dist_config: DistributedConfig,
133152
dataloader: StatefulDataLoader | None = None,
153+
max_checkpoints: int | None = None,
134154
) -> None:
135155
"""Saves the Dataloader state and the DDP checkpoint."""
136156
ckpt_path = Path(ckpt_path)
@@ -157,6 +177,9 @@ def save_checkpoint_ddp(
157177

158178
logger.info(f"Saved DDP checkpoint to {checkpoint_path}")
159179

180+
if max_checkpoints is not None and dist_config.is_main_process():
181+
prune_checkpoints(ckpt_path, max_checkpoints)
182+
160183

161184
def save_final_model_ddp(
162185
model: torch.nn.Module,
@@ -243,6 +266,7 @@ def save_checkpoint_mfsdp(
243266
dist_config: DistributedConfig,
244267
dataloader: StatefulDataLoader | None = None,
245268
epoch: int = 0,
269+
max_checkpoints: int | None = None,
246270
) -> None:
247271
"""Save mFSDP distributed checkpoint.
248272
@@ -255,6 +279,7 @@ def save_checkpoint_mfsdp(
255279
dist_config: The distributed configuration.
256280
dataloader: The dataloader to save.
257281
epoch: The epoch number to save the checkpoint.
282+
max_checkpoints: The maximum number of checkpoints to keep.
258283
"""
259284
ckpt_path = Path(ckpt_path)
260285
checkpoint_path = ckpt_path / f"step_{step}"
@@ -279,6 +304,9 @@ def save_checkpoint_mfsdp(
279304
if dist_config.is_main_process():
280305
logger.info(f"Saved mFSDP checkpoint to {checkpoint_path}")
281306

307+
if max_checkpoints is not None and dist_config.is_main_process():
308+
prune_checkpoints(ckpt_path, max_checkpoints)
309+
282310

283311
def save_final_model_mfsdp(
284312
model: torch.nn.Module,
@@ -369,6 +397,7 @@ def load_checkpoint_fsdp2(
369397
ckpt_path: str | os.PathLike,
370398
dist_config: DistributedConfig,
371399
dataloader: StatefulDataLoader | None = None,
400+
process_group: torch.distributed.ProcessGroup | None = None,
372401
) -> CheckpointOutput:
373402
"""Load FSDP2 checkpoint.
374403
@@ -379,6 +408,7 @@ def load_checkpoint_fsdp2(
379408
ckpt_path: The directory containing checkpoints.
380409
dist_config: The distributed configuration.
381410
dataloader: The dataloader to load.
411+
process_group: The process group to use for checkpointing.
382412
"""
383413
checkpoint_path, _ = get_latest_checkpoint(ckpt_path)
384414
if not checkpoint_path:
@@ -392,7 +422,7 @@ def load_checkpoint_fsdp2(
392422
)
393423

394424
state_dict = {"app": app_state}
395-
dcp.load(state_dict, checkpoint_id=checkpoint_path)
425+
dcp_load(state_dict, checkpoint_id=checkpoint_path, process_group=process_group)
396426

397427
if dataloader is not None:
398428
load_dataloader(
@@ -416,6 +446,9 @@ def save_checkpoint_fsdp2(
416446
epoch: int,
417447
dist_config: DistributedConfig,
418448
dataloader: StatefulDataLoader | None = None,
449+
process_group: torch.distributed.ProcessGroup | None = None,
450+
max_checkpoints: int | None = None,
451+
async_save: bool = False,
419452
) -> None:
420453
"""Save FSDP2 checkpoint.
421454
@@ -428,6 +461,9 @@ def save_checkpoint_fsdp2(
428461
epoch: The epoch number to save the checkpoint.
429462
dist_config: The distributed configuration.
430463
dataloader: The dataloader to save.
464+
process_group: The process group to use for checkpointing.
465+
max_checkpoints: The maximum number of checkpoints to keep.
466+
async_save: Whether to save the checkpoint asynchronously.
431467
"""
432468
ckpt_path = Path(ckpt_path)
433469
checkpoint_path = ckpt_path / f"step_{step}"
@@ -441,17 +477,24 @@ def save_checkpoint_fsdp2(
441477
)
442478
logger.info(f"Saved FSDP2 dataloader to {ckpt_path}")
443479

444-
state_dict = {
445-
"app": AppState(
446-
model=model,
447-
optimizer=optimizer,
448-
scheduler=scheduler,
449-
step=step,
450-
epoch=epoch,
451-
)
452-
}
453-
dcp.save(state_dict=state_dict, checkpoint_id=checkpoint_path)
454-
logger.info(f"Saved distributed FSDP2 checkpoint to {checkpoint_path}")
480+
# If we're using asynchronous checkpointing, make sure we only have one checkpoint future at a time.
481+
if async_save and "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None:
482+
_ckpt_futures["fsdp2"].result()
483+
484+
# Clear GPU cache before checkpointing to free up fragmented memory.
485+
gc.collect()
486+
torch.cuda.empty_cache()
487+
torch.distributed.barrier(group=process_group)
488+
489+
state_dict = {"app": AppState(model=model, optimizer=optimizer, scheduler=scheduler, step=step, epoch=epoch)}
490+
ckpt_save_func = dcp_async_save if async_save else dcp_save
491+
_ckpt_futures["fsdp2"] = ckpt_save_func(state_dict, checkpoint_id=checkpoint_path, process_group=process_group)
492+
493+
if dist_config.is_main_process():
494+
logger.info(f"Saved distributed FSDP2 checkpoint to {checkpoint_path}")
495+
496+
if max_checkpoints is not None and dist_config.is_main_process():
497+
prune_checkpoints(ckpt_path, max_checkpoints)
455498

456499

457500
def save_final_model_fsdp2(

bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ checkpoint:
7070
save_final_model: true
7171
resume_from_checkpoint: true
7272
save_every_n_steps: 1_000
73+
max_checkpoints: 5 # Keep only the latest 5 checkpoints
74+
async_save: true # Whether to save the checkpoint asynchronously, currently only supported with FSDP2.
7375

7476
logger:
7577
frequency: 100

bionemo-recipes/recipes/esm2_native_te/train_ddp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,10 @@ def main(args: DictConfig) -> float | None:
157157
scheduler=scheduler,
158158
ckpt_path=ckpt_path,
159159
step=step,
160+
epoch=epoch,
160161
dist_config=dist_config,
161162
dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
162-
epoch=epoch,
163+
max_checkpoints=args.checkpoint.max_checkpoints,
163164
)
164165

165166
step += 1

bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,10 @@ def main(args: DictConfig) -> float | None:
188188
scheduler=scheduler,
189189
ckpt_path=ckpt_path,
190190
step=step,
191-
dist_config=dist_config,
192-
dataloader=train_dataloader,
193191
epoch=epoch,
192+
dist_config=dist_config,
193+
dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
194+
max_checkpoints=args.checkpoint.max_checkpoints,
194195
)
195196

196197
step += 1

bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def main(args: DictConfig) -> float | None:
166166
epoch=epoch,
167167
dist_config=dist_config,
168168
dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
169+
max_checkpoints=args.checkpoint.max_checkpoints,
169170
)
170171

171172
step += 1

bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def main(args: DictConfig) -> float | None:
207207
epoch=epoch,
208208
dist_config=dist_config,
209209
dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
210+
max_checkpoints=args.checkpoint.max_checkpoints,
210211
)
211212

212213
step += 1

bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,10 @@ def main(args: DictConfig) -> float | None:
175175
scheduler=scheduler,
176176
ckpt_path=ckpt_path,
177177
step=step,
178+
epoch=epoch,
178179
dist_config=dist_config,
179180
dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
180-
epoch=epoch,
181+
max_checkpoints=args.checkpoint.max_checkpoints,
181182
)
182183

183184
step += 1

0 commit comments

Comments (0)