Commit f6dbb49

fix distributed integration tests
1 parent 31ec560 commit f6dbb49

File tree

2 files changed: +95 -35 lines changed


tests/tests_fabric/plugins/io/test_distributed_async_io.py

Lines changed: 14 additions & 0 deletions
@@ -140,6 +140,16 @@ def test_async_checkpointio_storage_options_not_supported(tmp_path):
 
 
 # --- integration test to verify the checkpoint is actually saved and loaded asynchronously ---
+
+
+def _broadcast_from_rank0(fabric: Fabric, obj):
+    """Broadcast an object from rank0 once Fabric has launched."""
+    fabric.barrier("pre_broadcast")
+    obj = fabric.broadcast(obj)
+    fabric.barrier("post_broadcast")
+    return obj
+
+
 class SimpleModel(nn.Module):
     def __init__(self):
         super().__init__()
@@ -204,7 +214,11 @@ def run_async_checkpoint_state_restoration(tmp_path, expected_strategy_name, acc
     # snapshot weights BEFORE save
     before = {k: v.detach().clone() for k, v in model.state_dict().items()}
     state = AttributeDict(model=model, optimizer=optimizer, step=1)
+
+    # rank0 decides canonical checkpoint path
     ckpt_path = tmp_path / "ckpt"
+    ckpt_path = _broadcast_from_rank0(fabric, ckpt_path)
+
     fabric.save(ckpt_path, state)
 
     # Wait for DistributedAsyncCheckpointIO to finish writing checkpoint metadata.
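Aside: the `_broadcast_from_rank0` helper added above only works once `fabric.launch()` has spawned the processes and collectives are available. A minimal standalone sketch of the same barrier/broadcast pattern (illustrative only, not part of this commit; the script, the `/tmp` paths, and the two-process CPU `ddp_spawn` setup are assumptions):

# sketch_broadcast_path.py -- illustrative, not part of this commit
from pathlib import Path

from lightning.fabric import Fabric


def run(fabric: Fabric) -> None:
    # Each rank starts with its own candidate path (like per-worker tmp_path under pytest).
    local_path = Path("/tmp") / f"ckpt_rank_{fabric.global_rank}"

    # Barrier, broadcast from rank 0 (the default src), barrier again.
    fabric.barrier()
    shared_path = fabric.broadcast(local_path)
    fabric.barrier()

    # Every rank now holds rank 0's path.
    print(f"rank {fabric.global_rank}: {shared_path}")


if __name__ == "__main__":
    Fabric(accelerator="cpu", devices=2, strategy="ddp_spawn").launch(run)

Run as a plain Python script; both ranks should print the rank-0 path, which is the agreement the test relies on before saving the checkpoint.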
Lines changed: 81 additions & 35 deletions
@@ -1,16 +1,6 @@
 # Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Licensed under the Apache License, Version 2.0
+
 import os
 import time
 from pathlib import Path
@@ -23,62 +13,96 @@
 from lightning.pytorch.demos import BoringModel
 from tests_pytorch.helpers.runif import RunIf
 
-# --- integration test to verify the checkpoint is actually saved and loaded asynchronously ---
+# -------------------------------------------------------------------------
+# Helpers
+# -------------------------------------------------------------------------
+
 
+def _sync_across_ranks(trainer, obj):
+    """Broadcast an object from rank 0 once the strategy/process group exists."""
+    trainer.strategy.barrier()
+    obj = trainer.strategy.broadcast(obj, src=0)
+    trainer.strategy.barrier()
+    return obj
 
-def _wait_for_dcp_metadata(path: Path, timeout=10):
-    # writing files in CI can be slow,
-    # and DCP writes a metadata file last,
-    # so we can wait for that to appear to ensure the checkpoint is ready
+
+def _wait_for_dcp_metadata(path: Path, timeout: int = 10):
+    """DCP writes metadata last; wait until it appears."""
     start = time.time()
     while True:
-        # DCP metadata file pattern
         if any(p.name.startswith(".metadata") for p in path.iterdir()):
             return
         if time.time() - start > timeout:
             raise RuntimeError("Checkpoint metadata not visible yet")
         time.sleep(0.1)
 
 
+def _find_checkpoint(tmp_path: Path):
+    """Poll until a checkpoint file exists (async IO may delay visibility)."""
+    ckpt_dir = tmp_path / "lightning_logs" / "version_0" / "checkpoints"
+
+    for _ in range(100):
+        files = list(ckpt_dir.glob("*.ckpt"))
+        if files:
+            return max(files, key=os.path.getctime)
+        time.sleep(0.1)
+
+    raise RuntimeError(f"Checkpoint file not found in {ckpt_dir}")
+
+
+# -------------------------------------------------------------------------
+# Core logic
+# -------------------------------------------------------------------------
+
+
 def save_model_checkpoint(tmp_path, expected_strategy_name, accelerator, devices):
     model = BoringModel()
+
     trainer = Trainer(
         default_root_dir=tmp_path,
         max_epochs=10,
         devices=devices,
-        plugins=[DistributedAsyncCheckpointIO()],
         accelerator=accelerator,
+        plugins=[DistributedAsyncCheckpointIO()],
     )
+
     assert trainer.strategy.__class__.__name__ == expected_strategy_name, (
         f"Expected strategy {expected_strategy_name}, but got {trainer.strategy.__class__.__name__}"
     )
+
     trainer.fit(model)
 
+    # Important:
+    # pytest standalone gives each worker a different tmp_path.
+    # After DDP init (fit), broadcast rank0's path so all ranks agree.
+    tmp_path = _sync_across_ranks(trainer, tmp_path)
 
-def get_checkpoint_path(tmp_path):
-    tmp_path = Path(tmp_path)
-    ckpt_path = tmp_path / "lightning_logs" / "version_0" / "checkpoints"
-    ckpt_files = list(ckpt_path.glob("*.ckpt"))
-    assert len(ckpt_files) > 0, "No checkpoint files found"
-    return max(ckpt_files, key=os.path.getctime)
+    return tmp_path  # noqa: RET504
 
 
 def load_model_checkpoint(tmp_path, expected_strategy_name, accelerator, devices):
-    last_ckpt = get_checkpoint_path(tmp_path)
-
     model = BoringModel()
+
     trainer = Trainer(
         default_root_dir=tmp_path,
         max_epochs=20,
         devices=devices,
-        plugins=[DistributedAsyncCheckpointIO()],
         accelerator=accelerator,
+        plugins=[DistributedAsyncCheckpointIO()],
     )
+
     assert trainer.strategy.__class__.__name__ == expected_strategy_name, (
         f"Expected strategy {expected_strategy_name}, but got {trainer.strategy.__class__.__name__}"
     )
 
-    trainer.fit(model, ckpt_path=last_ckpt)  # if loading works, it will restore to epoch 10 and continue to 20
+    last_ckpt = _find_checkpoint(Path(tmp_path))
+
+    trainer.fit(model, ckpt_path=last_ckpt)
+
+
+# -------------------------------------------------------------------------
+# Tests
+# -------------------------------------------------------------------------
 
 
 @RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
@@ -91,20 +115,42 @@ def load_model_checkpoint(tmp_path, expected_strategy_name, accelerator, devices
 )
 def test_trainer_distributed_async_checkpointio_integration_cuda(tmp_path, expected_strategy_name, devices):
     torch.manual_seed(1234)
-    save_model_checkpoint(tmp_path, expected_strategy_name, accelerator="cuda", devices=devices)
 
-    ckpt_path = get_checkpoint_path(tmp_path)
+    tmp_path = save_model_checkpoint(
+        tmp_path,
+        expected_strategy_name,
+        accelerator="cuda",
+        devices=devices,
+    )
+
+    ckpt_path = _find_checkpoint(Path(tmp_path))
     _wait_for_dcp_metadata(ckpt_path)
 
-    load_model_checkpoint(tmp_path, expected_strategy_name, accelerator="cuda", devices=devices)
+    load_model_checkpoint(
+        tmp_path,
+        expected_strategy_name,
+        accelerator="cuda",
+        devices=devices,
+    )
 
 
 @RunIf(min_torch="2.4", standalone=True)
 def test_trainer_distributed_async_checkpointio_integration_cpu(tmp_path):
     torch.manual_seed(1234)
-    save_model_checkpoint(tmp_path, "SingleDeviceStrategy", accelerator="cpu", devices=1)
 
-    ckpt_path = get_checkpoint_path(tmp_path)
+    save_model_checkpoint(
+        tmp_path,
+        "SingleDeviceStrategy",
+        accelerator="cpu",
+        devices=1,
+    )
+
+    ckpt_path = _find_checkpoint(Path(tmp_path))
     _wait_for_dcp_metadata(ckpt_path)
 
-    load_model_checkpoint(tmp_path, "SingleDeviceStrategy", accelerator="cpu", devices=1)
+    load_model_checkpoint(
+        tmp_path,
+        "SingleDeviceStrategy",
+        accelerator="cpu",
+        devices=1,
+    )
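Aside: `_wait_for_dcp_metadata` (kept as context above) relies on torch.distributed.checkpoint (DCP) writing its `.metadata` file only after the shard data, so its appearance signals that the checkpoint is complete. A rough single-process sketch of what a DCP checkpoint directory ends up containing (illustrative only, not part of this commit; the gloo process group, env settings, and temp-dir handling are assumptions):

# sketch_dcp_metadata.py -- illustrative, not part of this commit
import os
import tempfile

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp


def main() -> None:
    # Single-process "gloo" group so DCP's collectives have something to run on.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    ckpt_dir = tempfile.mkdtemp(prefix="dcp_ckpt_")
    dcp.save({"weight": torch.randn(4, 4)}, checkpoint_id=ckpt_dir)

    # Expect per-rank shard files plus ".metadata" -- the file the test polls for.
    print(sorted(os.listdir(ckpt_dir)))

    dist.destroy_process_group()


if __name__ == "__main__":
    main()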
