Commit c6856eb

Author: Seppo Enarvi

Test that stopping and resuming won't make a difference in the final model

1 parent 51b9a06 · commit c6856eb

File tree

2 files changed: +72 −53 lines


src/lightning/pytorch/callbacks/weight_averaging.py

+7 −2
@@ -50,10 +50,15 @@ class WeightAveraging(Callback):
 
     def __init__(
         self,
-        device: Optional[Union[torch.device, int]] = torch.device("cpu"),
+        device: Optional[Union[torch.device, str, int]] = "cpu",
         avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None,
     ):
-        self._device = device
+        # The default value is a string so that jsonargparse knows how to serialize it.
+        if isinstance(device, str):
+            self._device: Optional[Union[torch.device, int]] = torch.device(device)
+        else:
+            self._device = device
+
         self._avg_fn = avg_fn
         self._average_model: Optional[AveragedModel] = None
 
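The constructor now also accepts the device as a plain string, so that jsonargparse can serialize the default when the callback is configured through the Lightning CLI. A minimal usage sketch (illustrative only, not part of this commit; the Trainer options shown are arbitrary):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging

# Strings are converted internally with torch.device(); passing a torch.device
# instance or a device index keeps working as before.
callback = WeightAveraging(device="cpu")
trainer = Trainer(callbacks=[callback], max_epochs=8)
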
tests/tests_pytorch/callbacks/test_weight_averaging.py

+65 −51
@@ -12,43 +12,37 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from copy import deepcopy
 from pathlib import Path
 from typing import Any, Optional
 
 import pytest
 import torch
 from torch import Tensor, nn
 from torch.optim.swa_utils import get_swa_avg_fn
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import WeightAveraging
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
 from tests_pytorch.helpers.runif import RunIf
 
 
-class WeightAveragingTestModel(BoringModel):
-    def __init__(
-        self, batch_norm: bool = True, iterable_dataset: bool = False, crash_on_epoch: Optional[int] = None
-    ) -> None:
+class TestModel(BoringModel):
+    def __init__(self, batch_norm: bool = True) -> None:
         super().__init__()
         layers = [nn.Linear(32, 32)]
         if batch_norm:
             layers.append(nn.BatchNorm1d(32))
         layers += [nn.ReLU(), nn.Linear(32, 2)]
         self.layer = nn.Sequential(*layers)
-        self.iterable_dataset = iterable_dataset
-        self.crash_on_epoch = crash_on_epoch
+        self.crash_on_epoch = None
 
     def training_step(self, batch: Tensor, batch_idx: int) -> None:
         if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch:
-            raise Exception("CRASH TEST")
+            raise Exception("CRASH")
         return super().training_step(batch, batch_idx)
 
-    def train_dataloader(self) -> None:
-        dataset_class = RandomIterableDataset if self.iterable_dataset else RandomDataset
-        return DataLoader(dataset_class(32, 32), batch_size=4)
-
     def configure_optimizers(self) -> None:
         return torch.optim.SGD(self.layer.parameters(), lr=0.1)

@@ -194,95 +188,115 @@ def setup(self, trainer, pl_module, stage) -> None:
 @pytest.mark.parametrize("batch_norm", [True, False])
 @pytest.mark.parametrize("iterable_dataset", [True, False])
 def test_ema(tmp_path, batch_norm: bool, iterable_dataset: bool):
-    _train(tmp_path, EMATestCallback(), batch_norm=batch_norm, iterable_dataset=iterable_dataset)
+    model = TestModel(batch_norm=batch_norm)
+    dataset = RandomIterableDataset(32, 32) if iterable_dataset else RandomDataset(32, 32)
+    _train(model, dataset, tmp_path, EMATestCallback())
 
 
 @pytest.mark.parametrize(
     "accelerator", [pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), pytest.param("mps", marks=RunIf(mps=True))]
 )
 def test_ema_accelerator(tmp_path, accelerator):
-    _train(tmp_path, EMATestCallback(), accelerator=accelerator, devices=1)
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+    _train(model, dataset, tmp_path, EMATestCallback(), accelerator=accelerator, devices=1)
 
 
 @RunIf(min_cuda_gpus=2, standalone=True)
 def test_ema_ddp(tmp_path):
-    _train(tmp_path, EMATestCallback(devices=2), strategy="ddp", accelerator="gpu", devices=2)
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+    _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp", accelerator="gpu", devices=2)
 
 
 @RunIf(min_cuda_gpus=2)
 def test_ema_ddp_spawn(tmp_path):
-    _train(tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="gpu", devices=2)
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+    _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="gpu", devices=2)
 
 
 @RunIf(skip_windows=True)
 def test_ema_ddp_spawn_cpu(tmp_path):
-    _train(tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="cpu", devices=2)
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+    _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="cpu", devices=2)
 
 
-@pytest.mark.parametrize("crash_on_epoch", [1, 3])
+@pytest.mark.parametrize("crash_on_epoch", [1, 3, 5])
 def test_ema_resume(tmp_path, crash_on_epoch):
-    _train_and_resume(tmp_path, crash_on_epoch=crash_on_epoch)
+    dataset = RandomDataset(32, 32)
+    model1 = TestModel()
+    model2 = deepcopy(model1)
+
+    _train(model1, dataset, tmp_path, EMATestCallback())
+
+    model2.crash_on_epoch = crash_on_epoch
+    model2 = _train_and_resume(model2, dataset, tmp_path)
+
+    for param1, param2 in zip(model1.parameters(), model2.parameters()):
+        assert torch.allclose(param1, param2, atol=0.001)
 
 
 @RunIf(skip_windows=True)
 def test_ema_resume_ddp(tmp_path):
-    _train_and_resume(tmp_path, crash_on_epoch=3, use_ddp=True)
+    model = TestModel()
+    model.crash_on_epoch = 3
+    dataset = RandomDataset(32, 32)
+    _train_and_resume(model, dataset, tmp_path, strategy="ddp_spawn", devices=2)
 
 
 def test_swa(tmp_path):
-    _train(tmp_path, SWATestCallback())
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+    _train(model, dataset, tmp_path, SWATestCallback())
 
 
 def _train(
+    model: TestModel,
+    dataset: Dataset,
     tmp_path: str,
     callback: WeightAveraging,
-    batch_norm: bool = True,
     strategy: str = "auto",
     accelerator: str = "cpu",
     devices: int = 1,
-    iterable_dataset: bool = False,
     checkpoint_path: Optional[str] = None,
-    crash_on_epoch: Optional[int] = None,
-) -> None:
+    will_crash: bool = False,
+) -> TestModel:
+    deterministic = accelerator == "cpu"
     trainer = Trainer(
-        default_root_dir=tmp_path,
-        enable_progress_bar=False,
-        enable_model_summary=False,
+        accelerator=accelerator,
+        strategy=strategy,
+        devices=devices,
         logger=False,
+        callbacks=callback,
         max_epochs=8,
         num_sanity_val_steps=0,
-        callbacks=callback,
+        enable_checkpointing=will_crash,
+        enable_progress_bar=False,
+        enable_model_summary=False,
         accumulate_grad_batches=2,
-        strategy=strategy,
-        accelerator=accelerator,
-        devices=devices,
-    )
-    model = WeightAveragingTestModel(
-        batch_norm=batch_norm, iterable_dataset=iterable_dataset, crash_on_epoch=crash_on_epoch
+        deterministic=deterministic,
+        default_root_dir=tmp_path,
     )
-
-    if crash_on_epoch is None:
-        trainer.fit(model, ckpt_path=checkpoint_path)
+    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
+    if will_crash:
+        with pytest.raises(Exception, match="CRASH"):
+            trainer.fit(model, dataloader, ckpt_path=checkpoint_path)
     else:
-        with pytest.raises(Exception, match="CRASH TEST"):
-            trainer.fit(model, ckpt_path=checkpoint_path)
-
+        trainer.fit(model, dataloader, ckpt_path=checkpoint_path)
     assert trainer.lightning_module == model
 
 
-def _train_and_resume(tmp_path: str, crash_on_epoch: int, use_ddp: bool = False) -> None:
-    strategy = "ddp_spawn" if use_ddp else "auto"
-    devices = 2 if use_ddp else 1
-
-    _train(
-        tmp_path, EMATestCallback(devices=devices), strategy=strategy, devices=devices, crash_on_epoch=crash_on_epoch
-    )
+def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices: int = 1, **kwargs) -> TestModel:
+    _train(model, dataset, tmp_path, EMATestCallback(devices=devices), devices=devices, will_crash=True, **kwargs)
 
     checkpoint_dir = Path(tmp_path) / "checkpoints"
     checkpoint_names = os.listdir(checkpoint_dir)
     assert len(checkpoint_names) == 1
     checkpoint_path = str(checkpoint_dir / checkpoint_names[0])
 
-    _train(
-        tmp_path, EMATestCallback(devices=devices), strategy=strategy, devices=devices, checkpoint_path=checkpoint_path
-    )
+    model = TestModel.load_from_checkpoint(checkpoint_path)
+    callback = EMATestCallback(devices=devices)
+    _train(model, dataset, tmp_path, callback, devices=devices, checkpoint_path=checkpoint_path, **kwargs)
+    return model
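Assuming a standard Lightning development checkout, the new resume tests can be run on their own with pytest's -k selection (an illustrative command, not part of the commit):

pytest tests/tests_pytorch/callbacks/test_weight_averaging.py -k "ema_resume"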
