Skip to content

Commit 7411320

Browse files
committed
test distributed checkpointing with different recipes
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 0da20a0 commit 7411320

File tree

2 files changed

+59
-7
lines changed

2 files changed

+59
-7
lines changed

bionemo-recipes/recipes/llama3_native_te/tests/conftest.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import pytest
2121
import torch
22+
from transformer_engine.pytorch import fp8 as te_fp8
2223

2324

2425
sys.path.append(Path(__file__).parent.parent.as_posix())
@@ -61,6 +62,56 @@ def pytest_collection_modifyitems(items):
6162
items[:] = stats_tests + other_tests
6263

6364

65+
# ---------------------------------------------------------------------------
# FP8 recipe parametrization
# ---------------------------------------------------------------------------

# Maps each Transformer Engine recipe class name to the TE helper that reports
# whether the current hardware supports that recipe.  Each helper returns a
# ``(supported: bool, reason: str)`` tuple.  The Hydra override string is
# derived from the class name below, so the test id, the override target, and
# the support check can never drift out of sync.
_FP8_RECIPE_SUPPORT_CHECKS = {
    "DelayedScaling": te_fp8.check_fp8_support,
    "Float8CurrentScaling": te_fp8.check_fp8_support,
    "Float8BlockScaling": te_fp8.check_fp8_block_scaling_support,
    "MXFP8BlockScaling": te_fp8.check_mxfp8_support,
}

# Each entry: (recipe_class_name, hydra_overrides, check_fn) — the structure
# consumed by _parametrize_fp8_recipes().
_FP8_RECIPE_CONFIGS = [
    (
        name,
        [f"fp8_config.fp8_recipe=transformer_engine.common.recipe.{name}"],
        check_fn,
    )
    for name, check_fn in _FP8_RECIPE_SUPPORT_CHECKS.items()
]
92+
93+
94+
def _parametrize_fp8_recipes():
    """Build the ``pytest.param`` list for every FP8 recipe in ``_FP8_RECIPE_CONFIGS``.

    Each recipe's support-check callable is invoked once at collection time;
    recipes the current hardware cannot run are marked ``xfail`` with the
    reason reported by Transformer Engine.
    """

    def _as_param(entry):
        # entry is (recipe_class_name, hydra_overrides, check_fn).
        recipe_name, overrides, probe = entry
        supported, why = probe()
        return pytest.param(
            overrides,
            id=recipe_name,
            marks=pytest.mark.xfail(condition=not supported, reason=why),
        )

    return [_as_param(entry) for entry in _FP8_RECIPE_CONFIGS]
107+
108+
109+
@pytest.fixture(params=_parametrize_fp8_recipes())
def fp_recipe(request):
    """Parametrized fixture providing FP8 recipe Hydra overrides for each supported TE recipe.

    Yields one list of Hydra override strings per Transformer Engine recipe;
    recipes unsupported on the current hardware are xfail-marked by
    ``_parametrize_fp8_recipes``.
    """
    return request.param
113+
114+
64115
@pytest.fixture(scope="session", autouse=True)
65116
def device_mesh():
66117
"""Create a re-usable torch process group for testing.

bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -492,46 +492,46 @@ def test_checkpoint_pruning_with_files(tmp_path):
492492
]
493493

494494

def test_checkpoint_save_and_load_single_process_ddp_fp8_quantized(recipe_path, tmp_path, fp_recipe):
    """Test checkpoint save/resume for DDP with FP8 quantized model init."""
    # Combine the base FP8-quantized overrides with the parametrized recipe
    # overrides supplied by the fp_recipe fixture.
    overrides = list(_FP8_QUANTIZED_OVERRIDES) + list(fp_recipe)
    _run_single_process_checkpoint_test(
        recipe_path,
        tmp_path,
        main_ddp,
        ckpt_subdir_name="train_ddp",
        config_name="L0_sanity_cp",
        extra_overrides=overrides,
        is_ddp=True,
    )
506506

507507

def test_checkpoint_save_and_load_single_process_fsdp2_fp8_quantized(recipe_path, tmp_path, fp_recipe):
    """Test checkpoint save/resume for FSDP2 with FP8 quantized model init."""
    # Base FP8-quantized overrides plus the per-recipe overrides from the
    # parametrized fp_recipe fixture.
    overrides = list(_FP8_QUANTIZED_OVERRIDES) + list(fp_recipe)
    _run_single_process_checkpoint_test(
        recipe_path,
        tmp_path,
        main_fsdp2,
        ckpt_subdir_name="train_fsdp2",
        config_name="L0_sanity_cp",
        extra_overrides=overrides,
        is_ddp=False,
    )
519519

520520

def test_checkpoint_save_and_load_single_process_fsdp2_cp_fp8_quantized(recipe_path, tmp_path, fp_recipe):
    """Test checkpoint save/resume for FSDP2 with context parallelism and FP8 quantized model init."""
    # Base FP8-quantized overrides plus the per-recipe overrides from the
    # parametrized fp_recipe fixture.
    overrides = list(_FP8_QUANTIZED_OVERRIDES) + list(fp_recipe)
    _run_single_process_checkpoint_test(
        recipe_path,
        tmp_path,
        main_fsdp2_cp,
        ckpt_subdir_name="train_fsdp2",
        config_name="L0_sanity_cp",
        extra_overrides=overrides,
        is_ddp=False,
    )
532532

533533

534-
def test_checkpoint_save_and_load_single_process_fsdp2_cp_fp8_quantized_async(recipe_path, tmp_path):
534+
def test_checkpoint_save_and_load_single_process_fsdp2_cp_fp8_quantized_async(recipe_path, tmp_path, fp_recipe):
535535
"""Test checkpoint save/resume for FSDP2+CP with FP8 quantized model init and async save.
536536
537537
This reproduces the corys_config scenario where async_save=true (the default)
@@ -545,6 +545,7 @@ def test_checkpoint_save_and_load_single_process_fsdp2_cp_fp8_quantized_async(re
545545
config_name="L0_sanity_cp",
546546
extra_overrides=[
547547
*_FP8_QUANTIZED_OVERRIDES,
548+
*fp_recipe,
548549
"checkpoint.async_save=true",
549550
],
550551
is_ddp=False,

0 commit comments

Comments
 (0)