File tree (expand / collapse) — bionemo-recipes/recipes/llama3_native_te — 2 files changed: +9 −1 lines changed
Original file line number | Diff line number | Diff line change
34  34   from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save
35  35   from torch.distributed.checkpoint.state_dict_saver import save as dcp_save
36  36   from torch.distributed.checkpoint.stateful import Stateful
    37+  from torch.distributed.tensor import DTensor
37  38   from torchdata.stateful_dataloader import StatefulDataLoader
    39+  from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
38  40
39  41   from distributed_config import DistributedConfig
40  42
@@ -338,6 +340,13 @@ def save_checkpoint_fsdp2(
338 340      checkpoint_path = ckpt_path / f"step_{step}"
339 341      checkpoint_path.mkdir(parents=True, exist_ok=True)
340 342
    343+     model_params = (p.to_local() if isinstance(p, DTensor) else p for p in model.parameters())
    344+     if async_save and any((isinstance(p, Float8Tensor) for p in model_params)):
    345+         logger.warning(
    346+             "Async checkpointing is not supported for FP8 models, falling back to synchronous checkpointing."
    347+         )
    348+         async_save = False
    349+
341 350      if dataloader is not None:
342 351          save_dataloader(
343 352              dataloader=dataloader,
Original file line number | Diff line number | Diff line change
@@ -531,7 +531,6 @@ def test_checkpoint_save_and_load_single_process_fsdp2_cp_fp8_quantized(recipe_p
531 531      )
532 532

533 533
534-     @pytest.mark.xfail()
535 534  def test_checkpoint_save_and_load_single_process_fsdp2_cp_fp8_quantized_async(recipe_path, tmp_path):
536 535      """Test checkpoint save/resume for FSDP2+CP with FP8 quantized model init and async save.
You can’t perform that action at this time.
0 commit comments