|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 | import os
|
| 15 | +from copy import deepcopy |
15 | 16 | from pathlib import Path
|
16 | 17 | from typing import Any, Optional
|
17 | 18 |
|
18 | 19 | import pytest
|
19 | 20 | import torch
|
20 | 21 | from torch import Tensor, nn
|
21 | 22 | from torch.optim.swa_utils import get_swa_avg_fn
|
22 |
| -from torch.utils.data import DataLoader |
| 23 | +from torch.utils.data import DataLoader, Dataset |
23 | 24 |
|
24 | 25 | from lightning.pytorch import LightningModule, Trainer
|
25 | 26 | from lightning.pytorch.callbacks import WeightAveraging
|
26 | 27 | from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
|
27 | 28 | from tests_pytorch.helpers.runif import RunIf
|
28 | 29 |
|
29 | 30 |
|
30 |
| -class WeightAveragingTestModel(BoringModel): |
31 |
| - def __init__( |
32 |
| - self, batch_norm: bool = True, iterable_dataset: bool = False, crash_on_epoch: Optional[int] = None |
33 |
| - ) -> None: |
| 31 | +class TestModel(BoringModel): |
| 32 | + def __init__(self, batch_norm: bool = True) -> None: |
34 | 33 | super().__init__()
|
35 | 34 | layers = [nn.Linear(32, 32)]
|
36 | 35 | if batch_norm:
|
37 | 36 | layers.append(nn.BatchNorm1d(32))
|
38 | 37 | layers += [nn.ReLU(), nn.Linear(32, 2)]
|
39 | 38 | self.layer = nn.Sequential(*layers)
|
40 |
| - self.iterable_dataset = iterable_dataset |
41 |
| - self.crash_on_epoch = crash_on_epoch |
| 39 | + self.crash_on_epoch = None |
42 | 40 |
|
43 | 41 | def training_step(self, batch: Tensor, batch_idx: int) -> None:
|
44 | 42 | if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch:
|
45 |
| - raise Exception("CRASH TEST") |
| 43 | + raise Exception("CRASH") |
46 | 44 | return super().training_step(batch, batch_idx)
|
47 | 45 |
|
48 |
| - def train_dataloader(self) -> None: |
49 |
| - dataset_class = RandomIterableDataset if self.iterable_dataset else RandomDataset |
50 |
| - return DataLoader(dataset_class(32, 32), batch_size=4) |
51 |
| - |
52 | 46 | def configure_optimizers(self) -> None:
|
53 | 47 | return torch.optim.SGD(self.layer.parameters(), lr=0.1)
|
54 | 48 |
|
@@ -194,95 +188,115 @@ def setup(self, trainer, pl_module, stage) -> None:
|
194 | 188 | @pytest.mark.parametrize("batch_norm", [True, False])
|
195 | 189 | @pytest.mark.parametrize("iterable_dataset", [True, False])
|
196 | 190 | def test_ema(tmp_path, batch_norm: bool, iterable_dataset: bool):
|
197 |
| - _train(tmp_path, EMATestCallback(), batch_norm=batch_norm, iterable_dataset=iterable_dataset) |
| 191 | + model = TestModel(batch_norm=batch_norm) |
| 192 | + dataset = RandomIterableDataset(32, 32) if iterable_dataset else RandomDataset(32, 32) |
| 193 | + _train(model, dataset, tmp_path, EMATestCallback()) |
198 | 194 |
|
199 | 195 |
|
200 | 196 | @pytest.mark.parametrize(
|
201 | 197 | "accelerator", [pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), pytest.param("mps", marks=RunIf(mps=True))]
|
202 | 198 | )
|
203 | 199 | def test_ema_accelerator(tmp_path, accelerator):
|
204 |
| - _train(tmp_path, EMATestCallback(), accelerator=accelerator, devices=1) |
| 200 | + model = TestModel() |
| 201 | + dataset = RandomDataset(32, 32) |
| 202 | + _train(model, dataset, tmp_path, EMATestCallback(), accelerator=accelerator, devices=1) |
205 | 203 |
|
206 | 204 |
|
207 | 205 | @RunIf(min_cuda_gpus=2, standalone=True)
|
208 | 206 | def test_ema_ddp(tmp_path):
|
209 |
| - _train(tmp_path, EMATestCallback(devices=2), strategy="ddp", accelerator="gpu", devices=2) |
| 207 | + model = TestModel() |
| 208 | + dataset = RandomDataset(32, 32) |
| 209 | + _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp", accelerator="gpu", devices=2) |
210 | 210 |
|
211 | 211 |
|
212 | 212 | @RunIf(min_cuda_gpus=2)
|
213 | 213 | def test_ema_ddp_spawn(tmp_path):
|
214 |
| - _train(tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="gpu", devices=2) |
| 214 | + model = TestModel() |
| 215 | + dataset = RandomDataset(32, 32) |
| 216 | + _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="gpu", devices=2) |
215 | 217 |
|
216 | 218 |
|
217 | 219 | @RunIf(skip_windows=True)
|
218 | 220 | def test_ema_ddp_spawn_cpu(tmp_path):
|
219 |
| - _train(tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="cpu", devices=2) |
| 221 | + model = TestModel() |
| 222 | + dataset = RandomDataset(32, 32) |
| 223 | + _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="cpu", devices=2) |
220 | 224 |
|
221 | 225 |
|
222 |
| -@pytest.mark.parametrize("crash_on_epoch", [1, 3]) |
| 226 | +@pytest.mark.parametrize("crash_on_epoch", [1, 3, 5]) |
223 | 227 | def test_ema_resume(tmp_path, crash_on_epoch):
|
224 |
| - _train_and_resume(tmp_path, crash_on_epoch=crash_on_epoch) |
| 228 | + dataset = RandomDataset(32, 32) |
| 229 | + model1 = TestModel() |
| 230 | + model2 = deepcopy(model1) |
| 231 | + |
| 232 | + _train(model1, dataset, tmp_path, EMATestCallback()) |
| 233 | + |
| 234 | + model2.crash_on_epoch = crash_on_epoch |
| 235 | + model2 = _train_and_resume(model2, dataset, tmp_path) |
| 236 | + |
| 237 | + for param1, param2 in zip(model1.parameters(), model2.parameters()): |
| 238 | + assert torch.allclose(param1, param2, atol=0.001) |
225 | 239 |
|
226 | 240 |
|
227 | 241 | @RunIf(skip_windows=True)
|
228 | 242 | def test_ema_resume_ddp(tmp_path):
|
229 |
| - _train_and_resume(tmp_path, crash_on_epoch=3, use_ddp=True) |
| 243 | + model = TestModel() |
| 244 | + model.crash_on_epoch = 3 |
| 245 | + dataset = RandomDataset(32, 32) |
| 246 | + _train_and_resume(model, dataset, tmp_path, strategy="ddp_spawn", devices=2) |
230 | 247 |
|
231 | 248 |
|
232 | 249 | def test_swa(tmp_path):
|
233 |
| - _train(tmp_path, SWATestCallback()) |
| 250 | + model = TestModel() |
| 251 | + dataset = RandomDataset(32, 32) |
| 252 | + _train(model, dataset, tmp_path, SWATestCallback()) |
234 | 253 |
|
235 | 254 |
|
236 | 255 | def _train(
|
| 256 | + model: TestModel, |
| 257 | + dataset: Dataset, |
237 | 258 | tmp_path: str,
|
238 | 259 | callback: WeightAveraging,
|
239 |
| - batch_norm: bool = True, |
240 | 260 | strategy: str = "auto",
|
241 | 261 | accelerator: str = "cpu",
|
242 | 262 | devices: int = 1,
|
243 |
| - iterable_dataset: bool = False, |
244 | 263 | checkpoint_path: Optional[str] = None,
|
245 |
| - crash_on_epoch: Optional[int] = None, |
246 |
| -) -> None: |
| 264 | + will_crash: bool = False, |
| 265 | +) -> TestModel: |
| 266 | + deterministic = accelerator == "cpu" |
247 | 267 | trainer = Trainer(
|
248 |
| - default_root_dir=tmp_path, |
249 |
| - enable_progress_bar=False, |
250 |
| - enable_model_summary=False, |
| 268 | + accelerator=accelerator, |
| 269 | + strategy=strategy, |
| 270 | + devices=devices, |
251 | 271 | logger=False,
|
| 272 | + callbacks=callback, |
252 | 273 | max_epochs=8,
|
253 | 274 | num_sanity_val_steps=0,
|
254 |
| - callbacks=callback, |
| 275 | + enable_checkpointing=will_crash, |
| 276 | + enable_progress_bar=False, |
| 277 | + enable_model_summary=False, |
255 | 278 | accumulate_grad_batches=2,
|
256 |
| - strategy=strategy, |
257 |
| - accelerator=accelerator, |
258 |
| - devices=devices, |
259 |
| - ) |
260 |
| - model = WeightAveragingTestModel( |
261 |
| - batch_norm=batch_norm, iterable_dataset=iterable_dataset, crash_on_epoch=crash_on_epoch |
| 279 | + deterministic=deterministic, |
| 280 | + default_root_dir=tmp_path, |
262 | 281 | )
|
263 |
| - |
264 |
| - if crash_on_epoch is None: |
265 |
| - trainer.fit(model, ckpt_path=checkpoint_path) |
| 282 | + dataloader = DataLoader(dataset, batch_size=4, shuffle=False) |
| 283 | + if will_crash: |
| 284 | + with pytest.raises(Exception, match="CRASH"): |
| 285 | + trainer.fit(model, dataloader, ckpt_path=checkpoint_path) |
266 | 286 | else:
|
267 |
| - with pytest.raises(Exception, match="CRASH TEST"): |
268 |
| - trainer.fit(model, ckpt_path=checkpoint_path) |
269 |
| - |
| 287 | + trainer.fit(model, dataloader, ckpt_path=checkpoint_path) |
270 | 288 | assert trainer.lightning_module == model
|
271 | 289 |
|
272 | 290 |
|
273 |
| -def _train_and_resume(tmp_path: str, crash_on_epoch: int, use_ddp: bool = False) -> None: |
274 |
| - strategy = "ddp_spawn" if use_ddp else "auto" |
275 |
| - devices = 2 if use_ddp else 1 |
276 |
| - |
277 |
| - _train( |
278 |
| - tmp_path, EMATestCallback(devices=devices), strategy=strategy, devices=devices, crash_on_epoch=crash_on_epoch |
279 |
| - ) |
| 291 | +def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices: int = 1, **kwargs) -> TestModel: |
| 292 | + _train(model, dataset, tmp_path, EMATestCallback(devices=devices), devices=devices, will_crash=True, **kwargs) |
280 | 293 |
|
281 | 294 | checkpoint_dir = Path(tmp_path) / "checkpoints"
|
282 | 295 | checkpoint_names = os.listdir(checkpoint_dir)
|
283 | 296 | assert len(checkpoint_names) == 1
|
284 | 297 | checkpoint_path = str(checkpoint_dir / checkpoint_names[0])
|
285 | 298 |
|
286 |
| - _train( |
287 |
| - tmp_path, EMATestCallback(devices=devices), strategy=strategy, devices=devices, checkpoint_path=checkpoint_path |
288 |
| - ) |
| 299 | + model = TestModel.load_from_checkpoint(checkpoint_path) |
| 300 | + callback = EMATestCallback(devices=devices) |
| 301 | + _train(model, dataset, tmp_path, callback, devices=devices, checkpoint_path=checkpoint_path, **kwargs) |
| 302 | + return model |
0 commit comments