Skip to content

Commit 0da20a0

Browse files
committed
add warning for async save with fp8 params
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent d1f619e commit 0da20a0

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

bionemo-recipes/recipes/llama3_native_te/checkpoint.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
34 34
from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save
35 35
from torch.distributed.checkpoint.state_dict_saver import save as dcp_save
36 36
from torch.distributed.checkpoint.stateful import Stateful
37+
from torch.distributed.tensor import DTensor
37 38
from torchdata.stateful_dataloader import StatefulDataLoader
39+
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
38 40

39 41
from distributed_config import DistributedConfig
40 42

@@ -338,6 +340,13 @@ def save_checkpoint_fsdp2(
338 340
checkpoint_path = ckpt_path / f"step_{step}"
339 341
checkpoint_path.mkdir(parents=True, exist_ok=True)
340 342

343+
model_params = (p.to_local() if isinstance(p, DTensor) else p for p in model.parameters())
344+
if async_save and any((isinstance(p, Float8Tensor) for p in model_params)):
345+
logger.warning(
346+
"Async checkpointing is not supported for FP8 models, falling back to synchronous checkpointing."
347+
)
348+
async_save = False
349+
341 350
if dataloader is not None:
342 351
save_dataloader(
343 352
dataloader=dataloader,

bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,6 @@ def test_checkpoint_save_and_load_single_process_fsdp2_cp_fp8_quantized(recipe_p
531 531
)
532 532

533 533

534-
@pytest.mark.xfail()
535 534
def test_checkpoint_save_and_load_single_process_fsdp2_cp_fp8_quantized_async(recipe_path, tmp_path):
536 535
"""Test checkpoint save/resume for FSDP2+CP with FP8 quantized model init and async save.
537 536

0 commit comments

Comments
 (0)