# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0

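"""Integration tests verifying that checkpoints are actually saved and loaded
asynchronously by ``DistributedAsyncCheckpointIO`` (backed by torch.distributed
checkpoint, DCP)."""
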
import os
import time
from pathlib import Path

import pytest
import torch

from lightning.pytorch import Trainer
from lightning.pytorch.demos import BoringModel

# NOTE: the import location of DistributedAsyncCheckpointIO is assumed here;
# adjust to wherever the plugin actually lives in the package.
from lightning.pytorch.plugins.io import DistributedAsyncCheckpointIO
from tests_pytorch.helpers.runif import RunIf

# -------------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------------


def _sync_across_ranks(trainer, obj):
    """Broadcast an object from rank 0 once the strategy/process group exists."""
    trainer.strategy.barrier()
    obj = trainer.strategy.broadcast(obj, src=0)
    trainer.strategy.barrier()
    return obj


def _wait_for_dcp_metadata(path: Path, timeout: int = 10):
    """DCP writes its metadata file last (and CI filesystems can be slow); wait until it appears."""
    start = time.time()
    while True:
        if any(p.name.startswith(".metadata") for p in path.iterdir()):
            return
        if time.time() - start > timeout:
            raise RuntimeError("Checkpoint metadata not visible yet")
        time.sleep(0.1)


def _find_checkpoint(tmp_path: Path):
    """Poll until a checkpoint file exists (async IO may delay visibility)."""
    ckpt_dir = tmp_path / "lightning_logs" / "version_0" / "checkpoints"

    for _ in range(100):
        files = list(ckpt_dir.glob("*.ckpt"))
        if files:
            return max(files, key=os.path.getctime)
        time.sleep(0.1)

    raise RuntimeError(f"Checkpoint file not found in {ckpt_dir}")


# -------------------------------------------------------------------------
# Core logic
# -------------------------------------------------------------------------


def save_model_checkpoint(tmp_path, expected_strategy_name, accelerator, devices):
    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=10,
        devices=devices,
        accelerator=accelerator,
        plugins=[DistributedAsyncCheckpointIO()],
    )

    assert trainer.strategy.__class__.__name__ == expected_strategy_name, (
        f"Expected strategy {expected_strategy_name}, but got {trainer.strategy.__class__.__name__}"
    )

    trainer.fit(model)

    # Important:
    # pytest standalone gives each worker a different tmp_path.
    # After DDP init (fit), broadcast rank 0's path so all ranks agree.
    tmp_path = _sync_across_ranks(trainer, tmp_path)

    return tmp_path  # noqa: RET504


def load_model_checkpoint(tmp_path, expected_strategy_name, accelerator, devices):
    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=20,
        devices=devices,
        accelerator=accelerator,
        plugins=[DistributedAsyncCheckpointIO()],
    )

    assert trainer.strategy.__class__.__name__ == expected_strategy_name, (
        f"Expected strategy {expected_strategy_name}, but got {trainer.strategy.__class__.__name__}"
    )

    last_ckpt = _find_checkpoint(Path(tmp_path))

    # If loading works, training restores to epoch 10 and continues to epoch 20.
    trainer.fit(model, ckpt_path=last_ckpt)


# -------------------------------------------------------------------------
# Tests
# -------------------------------------------------------------------------


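# Each test trains to epoch 10, waits for the async checkpoint to become
# visible on disk, then resumes from it and trains to epoch 20 to confirm the
# asynchronously written checkpoint is loadable.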
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
@pytest.mark.parametrize(
    ("expected_strategy_name", "devices"),
    [
        # NOTE: the exact parametrization is elided in the diff; these cases are assumed
        ("SingleDeviceStrategy", 1),
        ("DDPStrategy", 2),
    ],
)
def test_trainer_distributed_async_checkpointio_integration_cuda(tmp_path, expected_strategy_name, devices):
    torch.manual_seed(1234)

    tmp_path = save_model_checkpoint(
        tmp_path,
        expected_strategy_name,
        accelerator="cuda",
        devices=devices,
    )

    ckpt_path = _find_checkpoint(Path(tmp_path))
    _wait_for_dcp_metadata(ckpt_path)

    load_model_checkpoint(
        tmp_path,
        expected_strategy_name,
        accelerator="cuda",
        devices=devices,
    )


@RunIf(min_torch="2.4", standalone=True)
def test_trainer_distributed_async_checkpointio_integration_cpu(tmp_path):
    torch.manual_seed(1234)

    save_model_checkpoint(
        tmp_path,
        "SingleDeviceStrategy",
        accelerator="cpu",
        devices=1,
    )

    ckpt_path = _find_checkpoint(Path(tmp_path))
    _wait_for_dcp_metadata(ckpt_path)

    load_model_checkpoint(
        tmp_path,
        "SingleDeviceStrategy",
        accelerator="cpu",
        devices=1,
    )