-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Expand file tree
/
Copy pathtest_spike.py
More file actions
95 lines (84 loc) · 4.18 KB
/
test_spike.py
File metadata and controls
95 lines (84 loc) · 4.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import contextlib
import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0
from lightning.fabric.utilities.spike import SpikeDetection, TrainingSpikeException
from tests_fabric.helpers.runif import RunIf
def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
    """Drive ``on_train_batch_end`` with a decaying loss curve containing one spiked batch.

    The baseline losses are ``1 / i`` for ``i`` in ``1..9``. On the process whose
    ``fabric.global_rank`` equals ``global_rank_spike`` the loss at batch index 4 is
    replaced by ``spike_value`` (or by ``3`` when ``spike_value`` is ``None``).

    Args:
        fabric: Object exposing ``global_rank``, ``device`` and ``call``; the hook
            ``"on_train_batch_end"`` is dispatched once per batch through ``fabric.call``.
        global_rank_spike: Rank on which the spiked loss is injected.
        spike_value: Loss value to inject at the spike batch, or ``None`` to use ``3``.
        should_raise: Whether the hook call for the spike batch is expected to raise
            ``TrainingSpikeException``. The expectation is applied on every rank, not
            only on the rank that received the spiked value.
    """
    spike_batch_idx = 4

    loss_vals = [1 / i for i in range(1, 10)]
    if fabric.global_rank == global_rank_spike:
        # Keep the integer literal 3 as the default spike so the tensor dtype is unchanged.
        loss_vals[spike_batch_idx] = 3 if spike_value is None else spike_value

    for batch_idx, loss_val in enumerate(loss_vals):
        # Only the spike batch may raise; every other batch must pass through silently.
        expect_spike = should_raise and batch_idx == spike_batch_idx
        context = pytest.raises(TrainingSpikeException) if expect_spike else contextlib.nullcontext()
        with context:
            fabric.call(
                "on_train_batch_end",
                fabric=fabric,
                loss=torch.tensor(loss_val, device=fabric.device),
                batch=None,
                batch_idx=batch_idx,
            )
@pytest.mark.flaky(reruns=3)
@pytest.mark.parametrize(
    ("global_rank_spike", "num_devices", "spike_value", "finite_only"),
    # NOTE: the multi-device cases are marked Linux-only because multiprocessing
    # on other platforms takes far too long.
    [
        pytest.param(0, 1, None, True),
        pytest.param(0, 1, None, False),
        pytest.param(0, 1, float("inf"), True),
        pytest.param(0, 1, float("inf"), False),
        pytest.param(0, 1, float("-inf"), True),
        pytest.param(0, 1, float("-inf"), False),
        pytest.param(0, 1, float("NaN"), True),
        pytest.param(0, 1, float("NaN"), False),
        pytest.param(0, 2, None, True, marks=RunIf(linux_only=True)),
        pytest.param(0, 2, None, False, marks=RunIf(linux_only=True)),
        pytest.param(1, 2, None, True, marks=RunIf(linux_only=True)),
        pytest.param(1, 2, None, False, marks=RunIf(linux_only=True)),
        pytest.param(0, 2, float("inf"), True, marks=RunIf(linux_only=True)),
        pytest.param(0, 2, float("inf"), False, marks=RunIf(linux_only=True)),
        pytest.param(1, 2, float("inf"), True, marks=RunIf(linux_only=True)),
        pytest.param(1, 2, float("inf"), False, marks=RunIf(linux_only=True)),
        pytest.param(0, 2, float("-inf"), True, marks=RunIf(linux_only=True)),
        pytest.param(0, 2, float("-inf"), False, marks=RunIf(linux_only=True)),
        pytest.param(1, 2, float("-inf"), True, marks=RunIf(linux_only=True)),
        pytest.param(1, 2, float("-inf"), False, marks=RunIf(linux_only=True)),
        pytest.param(0, 2, float("NaN"), True, marks=RunIf(linux_only=True)),
        pytest.param(0, 2, float("NaN"), False, marks=RunIf(linux_only=True)),
        pytest.param(1, 2, float("NaN"), True, marks=RunIf(linux_only=True)),
        pytest.param(1, 2, float("NaN"), False, marks=RunIf(linux_only=True)),
    ],
)
@pytest.mark.skipif(not _TORCHMETRICS_GREATER_EQUAL_1_0_0, reason="requires torchmetrics>=1.0.0")
def test_fabric_spike_detection_integration(tmp_path, global_rank_spike, num_devices, spike_value, finite_only):
    """Launch ``spike_detection_test`` under Fabric with a ``SpikeDetection`` callback.

    Each parametrized case picks the rank that receives the spiked loss, the number
    of CPU devices, the injected loss value and the callback's ``finite_only`` flag.
    """
    detector = SpikeDetection(exclude_batches_path=tmp_path, finite_only=finite_only)
    fabric = Fabric(
        accelerator="cpu",
        devices=num_devices,
        callbacks=[detector],
        strategy="ddp_spawn",
    )

    # The spike batch is expected to raise when any of the following holds:
    #   * finite_only is set       -> regular spikes as well as NaN / +-inf are expected to raise
    #   * spike_value is None      -> the default spike is injected, i.e. regular spike detection
    #   * spike_value is +inf      -> inf is far above every other loss, so it counts as a regular spike
    expected_to_raise = finite_only or spike_value is None or spike_value == float("inf")

    fabric.launch(
        spike_detection_test,
        global_rank_spike=global_rank_spike,
        spike_value=spike_value,
        should_raise=expected_to_raise,
    )
@pytest.mark.skipif(not _TORCHMETRICS_GREATER_EQUAL_1_0_0, reason="requires torchmetrics>=1.0.0")
def test_spike_detection_state_dict_roundtrip():
    """Check that a fresh ``SpikeDetection`` can load another instance's ``state_dict``.

    Regression guard: ``load_state_dict`` used to call
    ``self.running.load_state_dict(...)`` although the attribute set in ``__init__``
    is named ``self.running_mean``, so resuming a SpikeDetection callback failed with
    ``AttributeError: 'SpikeDetection' object has no attribute 'running'``.
    The test passes as long as the round trip completes without raising.
    """
    source_callback = SpikeDetection()
    restored_callback = SpikeDetection()

    saved_state = source_callback.state_dict()
    restored_callback.load_state_dict(saved_state)